From c335f6ef7751f379aac9f48f6c26cafc90f52103 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 16 Dec 2025 01:01:18 +0000 Subject: [PATCH 01/21] train with only layer distillation losses --- fast_llm/layers/language_model/head.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index b1d0c2ac..db768ca1 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -409,14 +409,23 @@ def _logits_cross_entropy_forward_backward( else: distillation_loss, distillation_grad = None, None - # TODO: de-allocate earlier. - del logits - # TODO: Accumulate grads in-place to reduce memory and compute overhead. grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) # TODO: Return individual losses? loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) + + # When using only activation distillation, loss and grad are None. + # Create zero tensors to allow activation distillation gradients to flow through. + if loss is None: + loss = torch.zeros(1, device=input_.device, dtype=input_.dtype, requires_grad=True) + if grad is None: + # Zero gradient means no loss at the head, but activation distillation gradients + grad = torch.zeros_like(logits) + + # TODO: de-allocate earlier. + del logits + if self.training and losses is not None: if dpo_loss is not None: losses[self._dpo_loss_name].append(dpo_loss.detach()) @@ -502,11 +511,12 @@ def _format_name(name: str) -> str: return name.replace("_", " ") -def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor: +def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor | None: tensors = [tensor for tensor in tensors if tensor is not None] if len(tensors) > 1: return sum(tensors) elif len(tensors) == 1: return tensors[0] else: - raise RuntimeError() + # All tensors are None - this is valid when using only activation distillation + return None From e06a4b2ca02b22dc56e798aabf0b8c30fe280417 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 16 Dec 2025 14:15:45 +0000 Subject: [PATCH 02/21] unscaled loss llogging + training with distillation loss factor = 0 --- fast_llm/layers/language_model/head.py | 53 +++++++++++++++++++------- 1 file changed, 39 insertions(+), 14 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index db768ca1..733311d3 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -370,11 +370,13 @@ def _logits_cross_entropy_forward_backward( logits_scale_factor=self._config.logits_scale_factor, target_format=TargetFormat.labels, ) + if self.training and losses is not None: + losses[self._ce_loss_name_unscaled].append(lm_loss.detach()) lm_loss = lm_loss * self._config.language_model_loss_factor else: lm_loss, lm_grad = None, None - if distillation_target is not None and self._config.distillation_loss_factor > 0.0: + if distillation_target is not None: if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), @@ -405,9 +407,9 @@ def _logits_cross_entropy_forward_backward( raise ValueError( f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" ) + if self.training and losses is not None: # we keep track of unscaled losses for model comparison purposes + 
losses[self._distillation_loss_name_unscaled].append(distillation_loss.detach()) distillation_loss = distillation_loss * self._config.distillation_loss_factor - else: - distillation_loss, distillation_grad = None, None # TODO: Accumulate grads in-place to reduce memory and compute overhead. grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) @@ -415,14 +417,6 @@ def _logits_cross_entropy_forward_backward( # TODO: Return individual losses? loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) - # When using only activation distillation, loss and grad are None. - # Create zero tensors to allow activation distillation gradients to flow through. - if loss is None: - loss = torch.zeros(1, device=input_.device, dtype=input_.dtype, requires_grad=True) - if grad is None: - # Zero gradient means no loss at the head, but activation distillation gradients - grad = torch.zeros_like(logits) - # TODO: de-allocate earlier. del logits @@ -443,6 +437,13 @@ def _loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name + @functools.cached_property + def _ce_loss_name_unscaled(self) -> str: + name = "language_model_loss_unscaled" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + @functools.cached_property def _z_loss_name(self) -> str: name = "z_loss" @@ -471,8 +472,24 @@ def _distillation_loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name + @functools.cached_property + def _distillation_loss_name_unscaled(self) -> str: + name = "distillation_loss_unscaled" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: loss_defs = [LossDef(name=self._loss_name, formatted_name=_format_name(self._loss_name), count=count)] + if self._config.distillation_model is None or self._config.language_model_loss_factor > 0.0: + # unscaled CE loss (NTP) + loss_defs = [ + LossDef( + name=self._ce_loss_name_unscaled, + formatted_name=_format_name(self._ce_loss_name_unscaled), + count=count, + ) + ] if self._config.logit_z_loss: loss_defs.append( LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) @@ -490,6 +507,15 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: count=count, ) ) + # unscaled distillation loss for comparison purposes + loss_defs.append( + LossDef( + name=self._distillation_loss_name_unscaled, + formatted_name=_format_name(self._distillation_loss_name_unscaled), + count=count, + ) + ) + # if we mix distillation loss and CE loss for NTP, we want to log both if self._config.language_model_loss_factor > 0.0: loss_defs.append( LossDef( @@ -511,12 +537,11 @@ def _format_name(name: str) -> str: return name.replace("_", " ") -def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor | None: +def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor: tensors = [tensor for tensor in tensors if tensor is not None] if len(tensors) > 1: return sum(tensors) elif len(tensors) == 1: return tensors[0] else: - # All tensors are None - this is valid when using only activation distillation - return None + raise RuntimeError() From 179ae25e9db3ecda3c75762288abe824c31e65fd Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 17 Dec 2025 21:07:54 +0000 Subject: [PATCH 03/21] make logging more explicit --- fast_llm/layers/language_model/config.py | 12 ++ fast_llm/layers/language_model/head.py | 217 +++++++++++++++-------- 2 files changed, 153 
insertions(+), 76 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 53dac289..13c6d87e 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -168,11 +168,21 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Factor to scale the language modeling loss by when using distillation.", hint=FieldHint.feature, ) + track_language_model_loss: bool = Field( + default=False, + desc="Track the unscaled language modeling loss for logging purposes. Will always do if language_model_loss_factor > 0.", + hint=FieldHint.feature, + ) distillation_loss_factor: float = Field( default=1.0, desc="Factor to scale the distillation loss by when using distillation.", hint=FieldHint.feature, ) + track_distillation_loss: bool = Field( + default=False, + desc="Track the unscaled distillation loss for logging purposes. Will always do if distillation_loss_factor > 0.", + hint=FieldHint.feature, + ) logits_scale_factor: float = Field( default=1.0, desc="Multiply output logits by scale factor.", @@ -243,6 +253,8 @@ def _validate(self) -> None: else: self.language_model_loss_factor = 0.0 super()._validate() + if self.distillation_model is None: + Assert.is_(self.track_distillation_loss, False) assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both @property diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 733311d3..e785c09e 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -113,6 +113,12 @@ def __init__( peft=self._peft, ) + self._compute_lm_loss = self.config.language_model_loss_factor > 0.0 or self.config.track_language_model_loss + self._compute_dpo_loss = self._config.enable_dpo + self._compute_distillation_loss = self._config.distillation_model is not None and ( + self._config.distillation_loss_factor > 0.0 or self._config.track_distillation_loss + ) + def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: @@ -137,8 +143,6 @@ def forward( # TODO: Drop autograd entirely. # TODO: Skip cross-entropy backward if not needed. language_model_loss = self._forward(input_, kwargs, losses) - if losses is not None and language_model_loss is not None: - losses[self._loss_name].append(language_model_loss.detach()) # TODO: Return the model output when needed. if self._is_last_head: # Last head should return the loss for backward. 
@@ -205,25 +209,22 @@ def _get_targets( if loss_mask is not None: loss_mask = loss_mask.flatten() - if self._config.distillation_model is None or self._config.language_model_loss_factor > 0.0: - lm_target = kwargs.get(LanguageModelKwargs.labels) - if lm_target is not None: - # MTP: Shift the labels - lm_target_sequence_length = ( - lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - self._prediction_heads - ) - if LanguageModelKwargs.sequence_q_dim in kwargs: - Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) - lm_target_slice = slice( - self._prediction_distance, self._prediction_distance + lm_target_sequence_length - ) - lm_target = ( - lm_target[lm_target_slice] - if kwargs[LanguageModelKwargs.sequence_first] - else lm_target[:, lm_target_slice] - ).flatten() - else: - lm_target = None + lm_target = kwargs.get(LanguageModelKwargs.labels) + if lm_target is not None: + # MTP: Shift the labels + lm_target_sequence_length = ( + lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - self._prediction_heads + ) + if LanguageModelKwargs.sequence_q_dim in kwargs: + Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) + lm_target_slice = slice( + self._prediction_distance, self._prediction_distance + lm_target_sequence_length + ) + lm_target = ( + lm_target[lm_target_slice] + if kwargs[LanguageModelKwargs.sequence_first] + else lm_target[:, lm_target_slice] + ).flatten() targets = (dpo_target, lm_target, distillation_target, loss_mask) if self._sequence_parallel_logits: @@ -246,7 +247,7 @@ def _logits_cross_entropy_forward_backward_split( losses: dict | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: if self._config.cross_entropy_splits is None or targets is None: - loss, logit_input_grad = self._logits_cross_entropy_forward_backward( + loss, logit_input_grad = self._logits_loss_forward_backward( input_, targets, weight, grad_output, kwargs, losses ) if targets is None: @@ -279,7 +280,7 @@ def _logits_cross_entropy_forward_backward_split( for tensor in [logit_input, *targets, logit_input_grad] ] for logit_input_, *targets_, logit_input_grad_ in zip(*tensors_split, strict=True): - loss_, grad_ = self._logits_cross_entropy_forward_backward( + loss_, grad_ = self._logits_loss_forward_backward( logit_input_, targets_, weight, @@ -301,7 +302,7 @@ def _logits_cross_entropy_forward_backward_split( all_reduce(loss, group=self._parallel_dim.group) return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None - def _logits_cross_entropy_forward_backward( + def _logits_loss_forward_backward( self, input_: torch.Tensor, targets: tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None], @@ -359,7 +360,7 @@ def _logits_cross_entropy_forward_backward( else: dpo_loss, dpo_grad = None, None - if lm_target is not None: + if lm_target is not None and self._compute_lm_loss: lm_loss, lm_grad = cross_entropy_forward_backward( logits.flatten(0, -2), lm_target, @@ -370,13 +371,10 @@ def _logits_cross_entropy_forward_backward( logits_scale_factor=self._config.logits_scale_factor, target_format=TargetFormat.labels, ) - if self.training and losses is not None: - losses[self._ce_loss_name_unscaled].append(lm_loss.detach()) - lm_loss = lm_loss * self._config.language_model_loss_factor else: lm_loss, lm_grad = None, None - if distillation_target is not None: + if distillation_target is not None and self._compute_distillation_loss: if 
self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), @@ -407,39 +405,121 @@ def _logits_cross_entropy_forward_backward( raise ValueError( f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" ) - if self.training and losses is not None: # we keep track of unscaled losses for model comparison purposes - losses[self._distillation_loss_name_unscaled].append(distillation_loss.detach()) - distillation_loss = distillation_loss * self._config.distillation_loss_factor - - # TODO: Accumulate grads in-place to reduce memory and compute overhead. - grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) - - # TODO: Return individual losses? - loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) + else: + distillation_loss, distillation_grad = None, None # TODO: de-allocate earlier. del logits + loss, grad = self._post_process_loss_and_grad( + dpo_loss, + dpo_grad, + lm_loss, + lm_grad, + distillation_loss, + distillation_grad, + losses, + loss_mask, + kwargs, + ) + + return loss, output_parallel_linear_backward(grad, context) if self.training else None - if self.training and losses is not None: - if dpo_loss is not None: + def _post_process_loss_and_grad( + self, + dpo_loss: torch.Tensor | None, + dpo_grad: torch.Tensor | None, + lm_loss: torch.Tensor | None, + lm_grad: torch.Tensor | None, + distillation_loss: torch.Tensor | None, + distillation_grad: torch.Tensor | None, + losses: dict | None, + loss_mask: torch.Tensor | None, + kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + If loss is provided (i.e. not None) it will be logged in scaled and unscaled version. The total loss is also logged. + + Arguments: + - Losses: unscaled losses from different components (DPO, LM CE, Distillation) + - Grads: gradients of the losses w.r.t. logits from different components, already scaled by loss factors. + """ + # Extremely explicit but easier to follow. + ############ + if dpo_loss is not None: + if self.training and losses is not None: losses[self._dpo_loss_name].append(dpo_loss.detach()) - if self._config.distillation_model is not None and distillation_loss is not None: + else: + Assert.is_(dpo_grad, None) + + if lm_loss is not None: + if self.training and losses is not None: + losses[self._lm_loss_name_unscaled].append(lm_loss.detach()) + lm_loss = lm_loss * self._config.language_model_loss_factor # does not need scaling by loss_scalor_df + if self.training and losses is not None: + losses[self._lm_loss_name].append(lm_loss.detach()) + else: + Assert.is_(lm_grad, None) + + if distillation_loss is not None: + # We need to scale the loss by (valid_tokens * num_micro_batches) / total_valid_tokens to correctly average the loss over micro-batches. + # The runner averages losses by dividing by num_micro_batches, so we need to account for that. + # Note: for grads this scaling is already in the 'grad_output' + total_valid_tokens = kwargs.get( + LanguageModelKwargs.total_valid_tokens + ) # number of not masked tokens across all micro-batches. 
+ num_micro_batches = kwargs.get("num_micro_batches", 1) + + if loss_mask is None or total_valid_tokens is None: + loss_scalor_df = 1 + else: + valid_tokens = loss_mask.sum() + # Scale by (valid_tokens * num_micro_batches) / total_valid_tokens + # This accounts for the runner dividing by num_micro_batches + loss_scalor_df = (valid_tokens * num_micro_batches) / total_valid_tokens + distillation_loss = distillation_loss * loss_scalor_df + if self.training and losses is not None: + losses[self._distillation_loss_name_unscaled].append(distillation_loss.detach()) + distillation_loss = distillation_loss * self._config.distillation_loss_factor + if self.training and losses is not None: losses[self._distillation_loss_name].append(distillation_loss.detach()) - if self._config.distillation_model is not None and lm_loss is not None: - losses[self._distillation_language_model_loss_name].append(lm_loss.detach()) + else: + Assert.is_(distillation_grad, None) - return loss, output_parallel_linear_backward(grad, context) if self.training else None + ############ + # TODO: Accumulate grads in-place to reduce memory and compute overhead. + grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) + total_loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) + if losses is not None and total_loss is not None: + losses[self._total_loss_name].append(total_loss.detach()) + + return total_loss, grad @functools.cached_property - def _loss_name(self) -> str: - name = "language_model_loss" + def _total_loss_name(self) -> str: + """ + Combined total scaled loss used for training. + """ + name = "lm_head_loss" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @functools.cached_property - def _ce_loss_name_unscaled(self) -> str: - name = "language_model_loss_unscaled" + def _lm_loss_name_unscaled(self) -> str: + """ + Unscaled language model cross-entropy loss. + """ + name = "lm_loss_unscaled" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + + @functools.cached_property + def _lm_loss_name(self) -> str: + """ + Scaled language model cross-entropy loss. 
+ """ + name = "lm_loss" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @@ -459,8 +539,8 @@ def _dpo_loss_name(self) -> str: return name @functools.cached_property - def _distillation_language_model_loss_name(self) -> str: - name = "distillation_language_model_loss" + def _distillation_loss_name_unscaled(self) -> str: + name = "distillation_loss_unscaled" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @@ -472,34 +552,28 @@ def _distillation_loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name - @functools.cached_property - def _distillation_loss_name_unscaled(self) -> str: - name = "distillation_loss_unscaled" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - loss_defs = [LossDef(name=self._loss_name, formatted_name=_format_name(self._loss_name), count=count)] - if self._config.distillation_model is None or self._config.language_model_loss_factor > 0.0: - # unscaled CE loss (NTP) - loss_defs = [ + loss_defs = [ + LossDef(name=self._total_loss_name, formatted_name=_format_name(self._total_loss_name), count=count) + ] + if self._compute_lm_loss: + loss_defs.append( LossDef( - name=self._ce_loss_name_unscaled, - formatted_name=_format_name(self._ce_loss_name_unscaled), + name=self._lm_loss_name_unscaled, + formatted_name=_format_name(self._lm_loss_name_unscaled), count=count, ) - ] + ) if self._config.logit_z_loss: loss_defs.append( LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) ) - if self._config.enable_dpo: + if self._compute_dpo_loss: loss_defs.append( LossDef(name=self._dpo_loss_name, formatted_name=_format_name(self._dpo_loss_name), count=count) ) - if self._config.distillation_model is not None: + if self._compute_distillation_loss: loss_defs.append( LossDef( name=self._distillation_loss_name, @@ -515,15 +589,6 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: count=count, ) ) - # if we mix distillation loss and CE loss for NTP, we want to log both - if self._config.language_model_loss_factor > 0.0: - loss_defs.append( - LossDef( - name=self._distillation_language_model_loss_name, - formatted_name=_format_name(self._distillation_language_model_loss_name), - count=count, - ) - ) return loss_defs @@ -544,4 +609,4 @@ def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor: elif len(tensors) == 1: return tensors[0] else: - raise RuntimeError() + raise RuntimeError("No tensors to add.") From 9968aac14c439823c6850e0dcc4e2210b5ad2cf3 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 17 Dec 2025 22:38:28 +0000 Subject: [PATCH 04/21] clean + tests --- fast_llm/layers/language_model/head.py | 24 ++---- tests/layers/test_lm_head.py | 107 +++++++++++++++++++++---- 2 files changed, 98 insertions(+), 33 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index e785c09e..8a460194 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -461,22 +461,7 @@ def _post_process_loss_and_grad( Assert.is_(lm_grad, None) if distillation_loss is not None: - # We need to scale the loss by (valid_tokens * num_micro_batches) / total_valid_tokens to correctly average the loss over micro-batches. - # The runner averages losses by dividing by num_micro_batches, so we need to account for that. 
- # Note: for grads this scaling is already in the 'grad_output' - total_valid_tokens = kwargs.get( - LanguageModelKwargs.total_valid_tokens - ) # number of not masked tokens across all micro-batches. - num_micro_batches = kwargs.get("num_micro_batches", 1) - - if loss_mask is None or total_valid_tokens is None: - loss_scalor_df = 1 - else: - valid_tokens = loss_mask.sum() - # Scale by (valid_tokens * num_micro_batches) / total_valid_tokens - # This accounts for the runner dividing by num_micro_batches - loss_scalor_df = (valid_tokens * num_micro_batches) / total_valid_tokens - distillation_loss = distillation_loss * loss_scalor_df + distillation_loss = distillation_loss if self.training and losses is not None: losses[self._distillation_loss_name_unscaled].append(distillation_loss.detach()) distillation_loss = distillation_loss * self._config.distillation_loss_factor @@ -564,6 +549,13 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: count=count, ) ) + loss_defs.append( + LossDef( + name=self._lm_loss_name, + formatted_name=_format_name(self._lm_loss_name), + count=count, + ) + ) if self._config.logit_z_loss: loss_defs.append( LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 623a30d8..88ff9d61 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -55,6 +55,8 @@ def _lm_head( logit_scale_factor: float = 1.0, logit_z_loss=0.0, distillation_loss_implementation: DistillationLossImpl = DistillationLossImpl.cross_entropy, + language_model_loss_factor: float = 1.0, + distillation_loss_factor: float = 1.0, ): hidden = torch.rms_norm( input_.to(rms_weight.dtype), @@ -69,23 +71,31 @@ def _lm_head( loss = _reverse_kl_loss( (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask ) - loss.backward(torch.full_like(loss, grad_output)) - return loss, None + # Apply distillation_loss_factor to grad_output for backward + loss.backward(torch.full_like(loss, grad_output * distillation_loss_factor)) + # Return scaled loss + return loss * distillation_loss_factor, None if logit_scale_factor != 1.0: logits *= logit_scale_factor z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) if logit_z_loss > 0 else None if target.ndim == logits.ndim: + # Distillation loss (cross-entropy with soft targets) loss = torch.nn.functional.cross_entropy( logits.flatten(0, -2), target.float().softmax(-1).flatten(0, -2), reduction="none" ) if loss_mask is not None: loss = loss * loss_mask.flatten() loss = loss.mean() + # Apply distillation_loss_factor + loss.backward(torch.full_like(loss, grad_output * distillation_loss_factor)) + return loss * distillation_loss_factor, z_loss else: + # Language model loss (cross-entropy with hard labels) loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) - loss.backward(torch.full_like(loss, grad_output)) - return loss, z_loss + # Apply language_model_loss_factor + loss.backward(torch.full_like(loss, grad_output * language_model_loss_factor)) + return loss * language_model_loss_factor, z_loss SEQUENCE_LENGTH = 200 @@ -154,6 +164,54 @@ def _lm_head( True, 1, ), + pytest.param( + { + "head": { + "distillation_model": "distillation", + "language_model_loss_factor": 0.0, + "track_language_model_loss": True, + "distillation_loss_factor": 1.0, + } + }, + {}, + False, + 1, + id="track_lm_zero_factor", + ), + pytest.param( + { + "head": { + "distillation_model": 
"distillation", + "language_model_loss_factor": 0.0, + "distillation_loss_factor": 0.0, + "track_language_model_loss": True, + "track_distillation_loss": True, + } + }, + {}, + False, + 1, + id="track_both_zero_factors", + ), + pytest.param( + { + "head": { + "distillation_model": "distillation", + "language_model_loss_factor": 0.0, + "distillation_loss_factor": 0.0, + "track_language_model_loss": False, + "track_distillation_loss": False, + } + }, + {}, + False, + 1, + marks=pytest.mark.xfail( + reason="No losses computed when all factors=0 and tracking=False, raises RuntimeError in _add_tensors", + strict=True, + ), + id="zero_factors_no_tracking", + ), ), ) def test_lm_head( @@ -292,6 +350,10 @@ def test_lm_head( logit_scale_factor=head_config.logits_scale_factor, logit_z_loss=head_config.logit_z_loss, distillation_loss_implementation=head_config.distillation_loss_implementation, + language_model_loss_factor=( + head_config.language_model_loss_factor if head_config.language_model_loss_factor is not None else 1.0 + ), + distillation_loss_factor=head_config.distillation_loss_factor, ) # Prepare LM head inputs @@ -303,20 +365,27 @@ def test_lm_head( head_input = torch.stack((shared_hidden, input_.detach())).requires_grad_() output_grad = torch.randn_like(shared_hidden) - loss_name = f"language_model_loss_{prediction_distance}" if prediction_distance > 0 else "language_model_loss" - loss_keys = {loss_name} + lm_head_loss_name = f"lm_head_loss_{prediction_distance}" if prediction_distance > 0 else "lm_head_loss" + expected_loss_keys = {lm_head_loss_name} + if head._compute_lm_loss: + lm_loss_name_unscaled = ( + f"lm_loss_unscaled_{prediction_distance}" if prediction_distance > 0 else "lm_loss_unscaled" + ) + lm_loss_name = f"lm_loss_{prediction_distance}" if prediction_distance > 0 else "lm_loss" + + expected_loss_keys.add(lm_loss_name_unscaled) + expected_loss_keys.add(lm_loss_name) if ref_z_loss is not None: - loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") - if head_config.distillation_model is not None: - loss_keys.add("distillation_loss") - if head_config.language_model_loss_factor > 0: - loss_keys.add("distillation_language_model_loss") + expected_loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") + if head._compute_distillation_loss: + expected_loss_keys.add("distillation_loss") + expected_loss_keys.add("distillation_loss_unscaled") Assert.eq( {loss_definition.name: loss_definition.count for loss_definition in head.get_loss_definitions()}, - {loss_key: 1 for loss_key in loss_keys}, + {loss_key: 1 for loss_key in expected_loss_keys}, ) - losses = {key: [] for key in loss_keys} + losses = {key: [] for key in expected_loss_keys} output, context = stage.forward(head_input, kwargs, losses) stage.backward(output_grad, context) @@ -325,16 +394,16 @@ def test_lm_head( 1e-5 if distributed.config.compute_dtype == DataType.float32 else 1e-4 ) * head_config.logits_scale_factor - Assert.eq(losses.keys(), loss_keys) - Assert.eq(len(losses[loss_name]), 1) + Assert.eq(losses.keys(), expected_loss_keys) + Assert.eq(len(losses[lm_head_loss_name]), 1) if ref_z_loss is not None: Assert.eq(len(losses["z_loss"]), 1) Assert.rms_close_relative(losses["z_loss"][0], ref_z_loss, threshold, min_threshold) - Assert.rms_close_relative(losses[loss_name][0], ref_loss, threshold, min_threshold) + Assert.rms_close_relative(losses[lm_head_loss_name][0], ref_loss, threshold, min_threshold) if head._is_last_head: - Assert.all_equal(output, 
losses[loss_name][0]) + Assert.all_equal(output, losses[lm_head_loss_name][0]) input_grad = head_input.grad else: Assert.all_equal(output, shared_hidden) @@ -344,3 +413,7 @@ def test_lm_head( Assert.rms_close_relative(input_grad, ref_input.grad, threshold, min_threshold) Assert.rms_close_relative(head.final_norm.weight.grad_buffer, ref_rms_weight.grad, threshold, min_threshold) Assert.rms_close_relative(logit_weight.grad_buffer, ref_logit_weight.grad, threshold, min_threshold) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 945c5a774bf30fbb088a818f12f5510e98f99bbb Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 17 Dec 2025 22:38:54 +0000 Subject: [PATCH 05/21] nvm --- tests/layers/test_lm_head.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 88ff9d61..c6d806db 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -413,7 +413,3 @@ def test_lm_head( Assert.rms_close_relative(input_grad, ref_input.grad, threshold, min_threshold) Assert.rms_close_relative(head.final_norm.weight.grad_buffer, ref_rms_weight.grad, threshold, min_threshold) Assert.rms_close_relative(logit_weight.grad_buffer, ref_logit_weight.grad, threshold, min_threshold) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) From 4b6e3d7503b0cf8a93aef156a0328c2b6dc67cc8 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 19 Dec 2025 21:28:55 +0000 Subject: [PATCH 06/21] forward KL --- fast_llm/functional/config.py | 1 + fast_llm/functional/cross_entropy.py | 128 +++++++++++++++++++++++++ fast_llm/layers/language_model/head.py | 21 +++- 3 files changed, 149 insertions(+), 1 deletion(-) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 4cfc3b61..20ed99fd 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -102,6 +102,7 @@ class CrossEntropyImpl(str, enum.Enum): class DistillationLossImpl(str, enum.Enum): reverse_kl = "reverse_kl" + forward_kl = "forward_kl" cross_entropy = "cross_entropy" diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 8c9ea939..5a618eea 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -359,3 +359,131 @@ def reverse_kl_forward_backward( group=group, ) return distillation_loss, distillation_grad + + +@torch.compile +def _forward_kl_forward_backward( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + group: ProcessGroup | None = None, + logits_scale_factor: float = 1.0, + teacher_softmax_temperature: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Forward KL: KL(p||q) where p=teacher, q=student. + This is reverse KL with roles swapped in the loss computation. + + Key insight: KL(p||q) = sum_i p_i * log(p_i/q_i) + = sum_i p_i * (log(p_i) - log(q_i)) + which is reverse KL with p and q swapped. + + However, we still need grad w.r.t. 
student logits, so gradient is different: + d/d(student_logits) KL(p||q) = student_probs - teacher_probs + """ + Assert.eq( + teacher_softmax_temperature, + 1, + msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel forward KL", + ) + Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel forward KL") + Assert.eq(target.shape, logits.shape) + assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, logits.shape[:-1]) + + # Compute log softmax for both teacher and student + teacher_log_probs = distributed_log_softmax(target.float(), group=group) + student_log_probs = distributed_log_softmax(logits, group=group) + + teacher_probs = teacher_log_probs.exp() + # Forward KL: p * log(p/q) = p * (log_p - log_q) + log_ratio = teacher_log_probs - student_log_probs + del teacher_log_probs + + # Compute loss: sum over vocab of teacher_probs * log_ratio + loss_terms = (teacher_probs * log_ratio).sum(dim=-1) + del log_ratio + + if loss_mask is not None: + valid = loss_mask.to(loss_terms.dtype) + loss_terms = loss_terms * valid + valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype)) + loss = loss_terms.sum() + + if group is not None: + all_reduce(loss, op=ReduceOp.SUM, group=group) + loss /= valid_tokens + + if grad_output is not None: + # Gradient: d/d(student_logits) KL(p||q) = student_probs - teacher_probs + student_probs = student_log_probs.exp() + grad_base = student_probs - teacher_probs + del student_probs, teacher_probs, student_log_probs + + if loss_mask is not None: + grad_base.mul_(loss_mask.to(logits.dtype).unsqueeze(-1)) + + grad_base.mul_(grad_output / valid_tokens) + grad = grad_base.to(logits.dtype) + else: + grad = None + + return loss.detach_(), grad + + +def forward_kl_forward_backward( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + group: ProcessGroup | None = None, + logits_scale_factor: float = 1.0, + teacher_softmax_temperature: float = 1.0, + target_format: TargetFormat = TargetFormat.labels, + sequence_parallel_logits: bool = False, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Compute forward KL divergence: KL(p||q) where p is the target distribution (teacher) and q is the predicted (student). + This is mode-covering (vs. mode-seeking for reverse KL) and useful for: + - Encouraging the model to cover all modes of the target distribution + - Spreading probability mass broadly across the target support + - Standard distillation scenarios where you want to match the full teacher distribution + + Key differences from reverse KL: + - Forward KL: KL(p||q) = mode-covering (spreads mass broadly) + - Reverse KL: KL(q||p) = mode-seeking (focuses on target modes) + + Takes: + logits: [BxS, V] or [B, S, V], where V is local vocab size + target: [BxS, V] or [B, S, V] (logits format) + loss_mask: [BxS] or [B, S] or None + ... + + Returns: + loss: Forward KL divergence loss + grad: Gradients w.r.t. 
logits + """ + + if sequence_parallel_logits: + # TODO: see hybrid dev branch where it is implemented + raise NotImplementedError("Sequence-parallel forward KL is not implemented yet, set vocab_parallel true") + + Assert.eq(target_format, TargetFormat.logits, msg="Forward KL only supports logits format") + Assert.eq(target.shape, logits.shape) + assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, logits.shape[:-1]) + + # TODO: implement fused? + distillation_loss, distillation_grad = _forward_kl_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=teacher_softmax_temperature, + group=group, + ) + return distillation_loss, distillation_grad diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 8a460194..b8a8f0cb 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -14,7 +14,11 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig -from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward +from fast_llm.functional.cross_entropy import ( + cross_entropy_forward_backward, + forward_kl_forward_backward, + reverse_kl_forward_backward, +) from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.block.block import Block @@ -390,6 +394,21 @@ def _logits_loss_forward_backward( sequence_parallel_logits=self._sequence_parallel_logits, ) + elif self._config.distillation_loss_implementation == DistillationLossImpl.forward_kl: + distillation_loss, distillation_grad = forward_kl_forward_backward( + logits.flatten(0, -2), + distillation_target, + loss_mask, + grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, + group=group, + logits_scale_factor=self._config.logits_scale_factor, + teacher_softmax_temperature=self._config.teacher_softmax_temperature, + target_format=( + TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits + ), + sequence_parallel_logits=self._sequence_parallel_logits, + ) + elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy: distillation_loss, distillation_grad = cross_entropy_forward_backward( logits.flatten(0, -2), From c5fefa0a13b1903bf88e7187790a94211b8d40cb Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 19 Dec 2025 22:19:52 +0000 Subject: [PATCH 07/21] test forward kl --- tests/functional/test_cross_entropy.py | 43 ++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index 72644d06..716c56ba 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -8,7 +8,11 @@ import torch from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig -from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward +from fast_llm.functional.cross_entropy import ( + cross_entropy_forward_backward, + 
forward_kl_forward_backward, + reverse_kl_forward_backward, +) from fast_llm.utils import Assert from tests.utils.utils import requires_cuda @@ -127,6 +131,41 @@ def test_reverse_kl(loss_masking, target_format): _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref, 1e-3) +def _forward_kl_forward_backward_torch(logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None): + # Manual reference: sum over vocab then average over all tokens (not just valid ones). + # Forward KL: KL(p||q) where p=teacher, q=student + logits = logits.detach().requires_grad_(True) + per_sample = torch.nn.functional.kl_div( + torch.log_softmax(logits.float(), dim=-1), + torch.log_softmax(target.float(), dim=-1), + reduction="none", + log_target=True, + ).sum(dim=-1) + if loss_mask is not None: + per_sample = per_sample * loss_mask + output = per_sample.sum() / per_sample.numel() + output.backward() + return output, logits.grad + + +@requires_cuda +@pytest.mark.slow +# TODO: Support the same parameterization as above in the reference implementation. +@pytest.mark.parametrize("loss_masking", [False, True]) +@pytest.mark.parametrize("target_format", (TargetFormat.logits,)) +def test_forward_kl(loss_masking, target_format): + logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) + out_ref, grad_ref = _forward_kl_forward_backward_torch(logits, target, loss_mask) + out, grad = forward_kl_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=1.0, + target_format=TargetFormat.logits, + ) + _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref, 1e-3) + + def _mp_worker(rank: int, world_size: int, init_method: str, fn_args: tuple): try: torch.distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size, init_method=init_method) @@ -189,7 +228,7 @@ def _compare_parallel_cross_entropy( def compare_parallel_cross_entropy(rank: int, group: torch.distributed.ProcessGroup): success = True - for function in (reverse_kl_forward_backward, cross_entropy_forward_backward): + for function in (reverse_kl_forward_backward, forward_kl_forward_backward, cross_entropy_forward_backward): for target_format in (TargetFormat.logits,): for loss_masking in [False, True]: try: From 411959616793a78f49e76b9c0767d055ba2c1971 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 19 Dec 2025 22:48:44 +0000 Subject: [PATCH 08/21] wip: report unscaled + kl loss --- fast_llm/layers/language_model/config.py | 35 ++++- fast_llm/layers/language_model/head.py | 158 +++++++++++++---------- 2 files changed, 122 insertions(+), 71 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 13c6d87e..807b3970 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -173,16 +173,37 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Track the unscaled language modeling loss for logging purposes. Will always do if language_model_loss_factor > 0.", hint=FieldHint.feature, ) - distillation_loss_factor: float = Field( - default=1.0, - desc="Factor to scale the distillation loss by when using distillation.", + track_forward_kl_loss: bool = Field( + default=False, + desc="Track the unscaled forward KL loss for logging purposes. 
Will always do if distillation_loss_implementation is forward_kl.", + hint=FieldHint.feature, + ) + track_reverse_kl_loss: bool = Field( + default=False, + desc="Track the unscaled reverse KL loss for logging purposes. Will always do if distillation_loss_implementation is reverse_kl.", hint=FieldHint.feature, ) - track_distillation_loss: bool = Field( + track_distillation_ce_loss: bool = Field( default=False, - desc="Track the unscaled distillation loss for logging purposes. Will always do if distillation_loss_factor > 0.", + desc="Track the unscaled distillation cross-entropy loss for logging purposes. Will always do if distillation_loss_implementation is cross_entropy.", + hint=FieldHint.feature, + ) + forward_kl_loss_factor: float = Field( + default=0.0, + desc="Factor to scale the forward KL loss by when using distillation with forward KL.", hint=FieldHint.feature, ) + reverse_kl_loss_factor: float = Field( + default=1.0, + desc="Factor to scale the reverse KL loss by when using distillation with reverse KL.", + hint=FieldHint.feature, + ) + distillation_ce_loss_factor: float = Field( + default=0.0, + desc="Factor to scale the distillation cross-entropy loss by when using distillation with cross-entropy.", + hint=FieldHint.feature, + ) + logits_scale_factor: float = Field( default=1.0, desc="Multiply output logits by scale factor.", @@ -254,7 +275,9 @@ def _validate(self) -> None: self.language_model_loss_factor = 0.0 super()._validate() if self.distillation_model is None: - Assert.is_(self.track_distillation_loss, False) + Assert.is_(self.track_forward_kl_loss, False) + Assert.is_(self.track_reverse_kl_loss, False) + Assert.is_(self.track_distillation_ce_loss, False) assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both @property diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index b8a8f0cb..040dc55d 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -13,7 +13,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig +from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig from fast_llm.functional.cross_entropy import ( cross_entropy_forward_backward, forward_kl_forward_backward, @@ -119,8 +119,18 @@ def __init__( self._compute_lm_loss = self.config.language_model_loss_factor > 0.0 or self.config.track_language_model_loss self._compute_dpo_loss = self._config.enable_dpo - self._compute_distillation_loss = self._config.distillation_model is not None and ( - self._config.distillation_loss_factor > 0.0 or self._config.track_distillation_loss + self._compute_rkl_loss = self._config.distillation_model is not None and ( + self._config.reverse_kl_loss_factor > 0.0 or self._config.track_reverse_kl_loss + ) + self._compute_kl_loss = self._config.distillation_model is not None and ( + self._config.forward_kl_loss_factor > 0.0 or self._config.track_forward_kl_loss + ) + self._compute_dist_ce_loss = self._config.distillation_model is not None and ( + self._config.distillation_ce_loss_factor > 0.0 or self._config.track_distillation_ce_loss + ) + + self._compute_distillation_loss = any( + [self._compute_rkl_loss, self._compute_kl_loss, 
self._compute_dist_ce_loss] ) def forward( @@ -378,13 +388,16 @@ def _logits_loss_forward_backward( else: lm_loss, lm_grad = None, None + distillation_rkl_grad, distillation_kl_grad, distillation_ce_grad = None, None, None + distillation_rkl_loss, distillation_kl_loss, distillation_ce_loss = None, None, None + if distillation_target is not None and self._compute_distillation_loss: - if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: - distillation_loss, distillation_grad = reverse_kl_forward_backward( + if self._compute_rkl_loss: + distillation_rkl_loss, distillation_rkl_grad = reverse_kl_forward_backward( logits.flatten(0, -2), distillation_target, loss_mask, - grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, + grad_output=grad_output * self._loss_coefficient * self._config.reverse_kl_loss_factor, group=group, logits_scale_factor=self._config.logits_scale_factor, teacher_softmax_temperature=self._config.teacher_softmax_temperature, @@ -394,12 +407,12 @@ def _logits_loss_forward_backward( sequence_parallel_logits=self._sequence_parallel_logits, ) - elif self._config.distillation_loss_implementation == DistillationLossImpl.forward_kl: - distillation_loss, distillation_grad = forward_kl_forward_backward( + if self._compute_kl_loss: + distillation_kl_loss, distillation_kl_grad = forward_kl_forward_backward( logits.flatten(0, -2), distillation_target, loss_mask, - grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, + grad_output=grad_output * self._loss_coefficient * self._config.forward_kl_loss_factor, group=group, logits_scale_factor=self._config.logits_scale_factor, teacher_softmax_temperature=self._config.teacher_softmax_temperature, @@ -409,13 +422,13 @@ def _logits_loss_forward_backward( sequence_parallel_logits=self._sequence_parallel_logits, ) - elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy: - distillation_loss, distillation_grad = cross_entropy_forward_backward( + if self._compute_dist_ce_loss: + distillation_ce_loss, distillation_ce_grad = cross_entropy_forward_backward( logits.flatten(0, -2), distillation_target, loss_mask, group=group, - grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, + grad_output=grad_output * self._loss_coefficient * self._config.distillation_ce_loss_factor, implementation=self._cross_entropy_impl, logits_scale_factor=self._config.logits_scale_factor, target_format=TargetFormat.logits, @@ -424,8 +437,6 @@ def _logits_loss_forward_backward( raise ValueError( f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" ) - else: - distillation_loss, distillation_grad = None, None # TODO: de-allocate earlier. 
del logits @@ -434,10 +445,13 @@ def _logits_loss_forward_backward( dpo_grad, lm_loss, lm_grad, - distillation_loss, - distillation_grad, + distillation_rkl_loss, + distillation_rkl_grad, + distillation_kl_loss, + distillation_kl_grad, + distillation_ce_loss, + distillation_ce_grad, losses, - loss_mask, kwargs, ) @@ -449,10 +463,13 @@ def _post_process_loss_and_grad( dpo_grad: torch.Tensor | None, lm_loss: torch.Tensor | None, lm_grad: torch.Tensor | None, - distillation_loss: torch.Tensor | None, - distillation_grad: torch.Tensor | None, + distillation_rkl_loss: torch.Tensor | None, + distillation_rkl_grad: torch.Tensor | None, + distillation_kl_loss: torch.Tensor | None, + distillation_kl_grad: torch.Tensor | None, + distillation_ce_loss: torch.Tensor | None, + distillation_ce_grad: torch.Tensor | None, losses: dict | None, - loss_mask: torch.Tensor | None, kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -463,6 +480,7 @@ def _post_process_loss_and_grad( - Grads: gradients of the losses w.r.t. logits from different components, already scaled by loss factors. """ # Extremely explicit but easier to follow. + # TODO: simplify / shrten / make seperate dataclass? ############ if dpo_loss is not None: if self.training and losses is not None: @@ -471,28 +489,38 @@ def _post_process_loss_and_grad( Assert.is_(dpo_grad, None) if lm_loss is not None: - if self.training and losses is not None: - losses[self._lm_loss_name_unscaled].append(lm_loss.detach()) - lm_loss = lm_loss * self._config.language_model_loss_factor # does not need scaling by loss_scalor_df if self.training and losses is not None: losses[self._lm_loss_name].append(lm_loss.detach()) + lm_loss = lm_loss * self._config.language_model_loss_factor else: Assert.is_(lm_grad, None) - if distillation_loss is not None: - distillation_loss = distillation_loss + if distillation_rkl_loss is not None: + distillation_rkl_loss = distillation_rkl_loss if self.training and losses is not None: - losses[self._distillation_loss_name_unscaled].append(distillation_loss.detach()) - distillation_loss = distillation_loss * self._config.distillation_loss_factor + losses[self._distillation_rkl_loss_name].append(distillation_rkl_loss.detach()) + distillation_rkl_loss = distillation_rkl_loss * self._config.distillation_loss_factor + else: + Assert.is_(distillation_rkl_grad, None) + if distillation_kl_loss is not None: + distillation_kl_loss = distillation_kl_loss + if self.training and losses is not None: + losses[self._distillation_kl_loss_name].append(distillation_kl_loss.detach()) + distillation_kl_loss = distillation_kl_loss * self._config.distillation_loss_factor + else: + Assert.is_(distillation_kl_grad, None) + if distillation_ce_loss is not None: + distillation_ce_loss = distillation_ce_loss if self.training and losses is not None: - losses[self._distillation_loss_name].append(distillation_loss.detach()) + losses[self._distillation_ce_loss_name].append(distillation_ce_loss.detach()) + distillation_ce_loss = distillation_ce_loss * self._config.distillation_loss_factor else: - Assert.is_(distillation_grad, None) + Assert.is_(distillation_ce_grad, None) ############ # TODO: Accumulate grads in-place to reduce memory and compute overhead. 
- grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) - total_loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) + grad = _add_tensors(dpo_grad, lm_grad, distillation_rkl_grad, distillation_kl_grad, distillation_ce_grad) + total_loss = _add_tensors(dpo_loss, lm_loss, distillation_rkl_loss, distillation_kl_loss, distillation_ce_loss) if losses is not None and total_loss is not None: losses[self._total_loss_name].append(total_loss.detach()) @@ -509,7 +537,7 @@ def _total_loss_name(self) -> str: return name @functools.cached_property - def _lm_loss_name_unscaled(self) -> str: + def _lm_loss_name(self) -> str: """ Unscaled language model cross-entropy loss. """ @@ -519,39 +547,36 @@ def _lm_loss_name_unscaled(self) -> str: return name @functools.cached_property - def _lm_loss_name(self) -> str: - """ - Scaled language model cross-entropy loss. - """ - name = "lm_loss" + def _z_loss_name(self) -> str: + name = "z_loss" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @functools.cached_property - def _z_loss_name(self) -> str: - name = "z_loss" + def _dpo_loss_name(self) -> str: + name = "dpo_loss" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @functools.cached_property - def _dpo_loss_name(self) -> str: - name = "dpo_loss" + def _distillation_kl_loss_name(self) -> str: + name = "distillation_kl_loss_unscaled" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @functools.cached_property - def _distillation_loss_name_unscaled(self) -> str: - name = "distillation_loss_unscaled" + def _distillation_rkl_loss_name(self) -> str: + name = "distillation_rkl_loss_unscaled" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @functools.cached_property - def _distillation_loss_name(self) -> str: - name = "distillation_loss" + def _distillation_ce_loss_name(self) -> str: + name = "distillation_ce_loss_unscaled" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @@ -568,13 +593,6 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: count=count, ) ) - loss_defs.append( - LossDef( - name=self._lm_loss_name, - formatted_name=_format_name(self._lm_loss_name), - count=count, - ) - ) if self._config.logit_z_loss: loss_defs.append( LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) @@ -585,21 +603,31 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: ) if self._compute_distillation_loss: - loss_defs.append( - LossDef( - name=self._distillation_loss_name, - formatted_name=_format_name(self._distillation_loss_name), - count=count, - ) - ) # unscaled distillation loss for comparison purposes - loss_defs.append( - LossDef( - name=self._distillation_loss_name_unscaled, - formatted_name=_format_name(self._distillation_loss_name_unscaled), - count=count, + if self._compute_kl_loss: + loss_defs.append( + LossDef( + name=self._distillation_kl_loss_name, + formatted_name=_format_name(self._distillation_kl_loss_name), + count=count, + ) + ) + if self._compute_rkl_loss: + loss_defs.append( + LossDef( + name=self._distillation_rkl_loss_name, + formatted_name=_format_name(self._distillation_rkl_loss_name), + count=count, + ) + ) + if self._compute_dist_ce_loss: + loss_defs.append( + LossDef( + name=self._distillation_ce_loss_name, + formatted_name=_format_name(self._distillation_ce_loss_name), + count=count, + ) ) - ) return loss_defs From 
b55a0a428fb85dc3ce16ec061d1bed5ea2ac619a Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 22 Dec 2025 13:42:48 +0000 Subject: [PATCH 09/21] loss config --- fast_llm/functional/cross_entropy.py | 2 + fast_llm/layers/language_model/config.py | 97 +---- fast_llm/layers/language_model/head.py | 408 +++++------------- .../layers/language_model/lm_head_losses.py | 280 ++++++++++++ 4 files changed, 405 insertions(+), 382 deletions(-) create mode 100644 fast_llm/layers/language_model/lm_head_losses.py diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 5a618eea..f534d8a7 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -314,6 +314,7 @@ def reverse_kl_forward_backward( teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, sequence_parallel_logits: bool = False, + **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). @@ -443,6 +444,7 @@ def forward_kl_forward_backward( teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, sequence_parallel_logits: bool = False, + **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute forward KL divergence: KL(p||q) where p is the target distribution (teacher) and q is the predicted (student). diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 807b3970..6fc92eaa 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -5,11 +5,11 @@ from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.language_model.lm_head_losses import LossConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -135,75 +135,22 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Configuration for the final normalization layer.", hint=FieldHint.architecture, ) + losses: dict[str, LossConfig] = Field( + default_factory=dict, + desc="A dictionary of loss names and their configurations.", + hint=FieldHint.core, + ) # TODO: Cleanup output_weight: ParameterConfig = Field( desc="Configuration for the LM output layer (weight). 
Ignored for tied embeddings", hint=FieldHint.architecture, ) - cross_entropy_implementation: CrossEntropyImpl = Field( - default=CrossEntropyImpl.auto, - desc="Implementation for the cross-entropy computation.", - hint=FieldHint.performance, - ) - distillation_loss_implementation: DistillationLossImpl = Field( - default=DistillationLossImpl.cross_entropy, - desc="Implementation for the distillation cross-entropy computation.", - hint=FieldHint.performance, - ) cross_entropy_splits: int | None = Field( default=None, desc="Split the logit and cross-entropy computation into this many fragment, to reduce memory usage.", hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) - logit_z_loss: float = Field( - default=0.0, - desc="Regularize the logits with Z-loss.", - doc="We recommend 1e-4 for stability, as used for training PaLM.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - language_model_loss_factor: float = Field( - default=None, - desc="Factor to scale the language modeling loss by when using distillation.", - hint=FieldHint.feature, - ) - track_language_model_loss: bool = Field( - default=False, - desc="Track the unscaled language modeling loss for logging purposes. Will always do if language_model_loss_factor > 0.", - hint=FieldHint.feature, - ) - track_forward_kl_loss: bool = Field( - default=False, - desc="Track the unscaled forward KL loss for logging purposes. Will always do if distillation_loss_implementation is forward_kl.", - hint=FieldHint.feature, - ) - track_reverse_kl_loss: bool = Field( - default=False, - desc="Track the unscaled reverse KL loss for logging purposes. Will always do if distillation_loss_implementation is reverse_kl.", - hint=FieldHint.feature, - ) - track_distillation_ce_loss: bool = Field( - default=False, - desc="Track the unscaled distillation cross-entropy loss for logging purposes. 
Will always do if distillation_loss_implementation is cross_entropy.", - hint=FieldHint.feature, - ) - forward_kl_loss_factor: float = Field( - default=0.0, - desc="Factor to scale the forward KL loss by when using distillation with forward KL.", - hint=FieldHint.feature, - ) - reverse_kl_loss_factor: float = Field( - default=1.0, - desc="Factor to scale the reverse KL loss by when using distillation with reverse KL.", - hint=FieldHint.feature, - ) - distillation_ce_loss_factor: float = Field( - default=0.0, - desc="Factor to scale the distillation cross-entropy loss by when using distillation with cross-entropy.", - hint=FieldHint.feature, - ) - logits_scale_factor: float = Field( default=1.0, desc="Multiply output logits by scale factor.", @@ -212,10 +159,10 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - teacher_softmax_temperature: float = Field( - default=1.0, - desc="Divides distillation target logits by this factor.", - doc="Divides distillation target logits by this factor.", + logit_z_loss: float = Field( + default=0.0, + desc="Regularize the logits with Z-loss.", + doc="We recommend 1e-4 for stability, as used for training PaLM.", hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) @@ -224,11 +171,6 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Name of the reference model to use for dpo.", hint=FieldHint.feature, ) - dpo_beta: float | None = Field( - default=1.0, - desc="Beta value for DPO loss.", - hint=FieldHint.feature, - ) distillation_model: str | None = Field( default=None, desc="Name of the reference model to use for knowledge distillation." @@ -268,16 +210,17 @@ def layer_class(self) -> "type[LanguageModelHead]": def _validate(self) -> None: with self._set_implicit_default(): - if self.language_model_loss_factor is None: - if self.distillation_model is None: - self.language_model_loss_factor = 1.0 - else: - self.language_model_loss_factor = 0.0 + if not self.losses: + self.losses = { + "lm_loss": LossConfig._from_dict( + {"type": "cross_entropy_lm_loss", "weight_scalor": 1.0, "log_it": True} + ) + } + + for loss_config in self.losses.values(): + if "dist" in loss_config.type: + assert self.distillation_model is not None, "Distillation loss requires a distillation model." 
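# Annotation (illustrative example, not part of the patch): the `losses` registry introduced here
# replaces the individual loss-factor and track_* flags with one dict of LossConfig entries. A head
# that mixes a down-weighted LM cross-entropy loss with reverse-KL distillation against a reference
# model could be configured roughly as follows. The type names ("cross_entropy_lm_loss", "revkl_dist")
# and the "weight_scalor"/"log_it" fields are the ones this patch defines; the `from_dict` call and
# the "teacher" reference-model name are assumptions for the sketch, not part of the patch.
from fast_llm.layers.language_model.config import LanguageModelHeadConfig

example_head_config = LanguageModelHeadConfig.from_dict(
    {
        "distillation_model": "teacher",
        "losses": {
            # Keep a small amount of next-token cross-entropy alongside distillation.
            "lm_loss": {"type": "cross_entropy_lm_loss", "weight_scalor": 0.1, "log_it": True},
            # Mode-seeking reverse KL against the teacher logits drives most of the gradient.
            "distillation_loss": {"type": "revkl_dist", "weight_scalor": 1.0, "log_it": True},
        },
    }
)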
super()._validate() - if self.distillation_model is None: - Assert.is_(self.track_forward_kl_loss, False) - Assert.is_(self.track_reverse_kl_loss, False) - Assert.is_(self.track_distillation_ce_loss, False) assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both @property diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 040dc55d..f23bb6f1 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -13,13 +13,6 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward -from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig -from fast_llm.functional.cross_entropy import ( - cross_entropy_forward_backward, - forward_kl_forward_backward, - reverse_kl_forward_backward, -) -from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockDimNames @@ -31,6 +24,7 @@ LanguageModelHeadConfig, LanguageModelKwargs, ) +from fast_llm.layers.language_model.lm_head_losses import Targets, _format_name from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, div, get_unique @@ -91,16 +85,6 @@ def __init__( if self._config.cross_entropy_splits is not None and self._sequence_parallel: assert not self._vocab_parallel - if not self._config.enable_dpo: - self._cross_entropy_impl = self._config.cross_entropy_implementation - if self._cross_entropy_impl == CrossEntropyImpl.auto: - if self._vocab_parallel: - self._cross_entropy_impl = CrossEntropyImpl.fused - elif TritonConfig.TRITON_ENABLED: - self._cross_entropy_impl = CrossEntropyImpl.triton - else: - self._cross_entropy_impl = CrossEntropyImpl.fused - self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) self.final_norm = self._config.normalization.get_layer( @@ -116,22 +100,10 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) - - self._compute_lm_loss = self.config.language_model_loss_factor > 0.0 or self.config.track_language_model_loss - self._compute_dpo_loss = self._config.enable_dpo - self._compute_rkl_loss = self._config.distillation_model is not None and ( - self._config.reverse_kl_loss_factor > 0.0 or self._config.track_reverse_kl_loss - ) - self._compute_kl_loss = self._config.distillation_model is not None and ( - self._config.forward_kl_loss_factor > 0.0 or self._config.track_forward_kl_loss - ) - self._compute_dist_ce_loss = self._config.distillation_model is not None and ( - self._config.distillation_ce_loss_factor > 0.0 or self._config.track_distillation_ce_loss - ) - - self._compute_distillation_loss = any( - [self._compute_rkl_loss, self._compute_kl_loss, self._compute_dist_ce_loss] - ) + self._formatted_loss_names = { + loss_name: loss_config.get_formatted_name(loss_name, self._prediction_distance) + for loss_name, loss_config in self._config.losses.items() + } def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None @@ -203,22 +175,25 @@ def _forward_backward( else: return loss, None - def _get_targets( - self, kwargs: dict - ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor 
| None] | None: - # Loss mask for distillation. (Labels are already masked.) + def _get_targets(self, kwargs: dict) -> Targets | None: + ( + lm_target, + dpo_target, + reference_model_logits, + loss_mask, + chosen_spans, + rejected_spans, + dpo_reference_model_logits, + ) = (None, None, None, None, None, None, None) if self._config.enable_dpo: dpo_target = kwargs.get(LanguageModelKwargs.labels) - lm_target = None - distillation_target = None - loss_mask = None + chosen_spans = kwargs.get(LanguageModelKwargs.chosen_spans) + rejected_spans = kwargs.get(LanguageModelKwargs.rejected_spans) + dpo_reference_model_logits = (kwargs.get(f"{self._config.dpo_reference_model}_logits"),) else: - dpo_target = None - if self._config.distillation_model is None: - distillation_target, loss_mask = None, None - else: + if self._config.distillation_model is not None: # Target is reference model logits. - distillation_target = kwargs[f"{self._config.distillation_model}_logits"].flatten(0, -2) + reference_model_logits = kwargs[f"{self._config.distillation_model}_logits"].flatten(0, -2) loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) if loss_mask is not None: loss_mask = loss_mask.flatten() @@ -240,12 +215,29 @@ def _get_targets( else lm_target[:, lm_target_slice] ).flatten() - targets = (dpo_target, lm_target, distillation_target, loss_mask) if self._sequence_parallel_logits: - targets = [None if target is None else split_op(target, self._parallel_dim.group, 0) for target in targets] - if not any(target is not None for target in targets): - # Simplify so we don't have to check every time. - targets = None + if dpo_target is not None: + dpo_target = split_op(dpo_target, self._parallel_dim.group, 0) + if lm_target is not None: + lm_target = split_op(lm_target, self._parallel_dim.group, 0) + if loss_mask is not None: + loss_mask = split_op(loss_mask, self._parallel_dim.group, 0) + if reference_model_logits is not None: + reference_model_logits = split_op(reference_model_logits, self._parallel_dim.group, 0) + + targets = Targets( + dpo_target=dpo_target, + lm_target=lm_target, + loss_mask=loss_mask, + chosen_spans=chosen_spans, + rejected_spans=rejected_spans, + reference_model_logits=reference_model_logits, + dpo_reference_model_logits=dpo_reference_model_logits, + ) + + # Return None if no targets are set + if not targets.has_any_target(): + return None return targets def get_output_weights(self) -> list[torch.Tensor]: @@ -254,7 +246,7 @@ def get_output_weights(self) -> list[torch.Tensor]: def _logits_cross_entropy_forward_backward_split( self, input_: torch.Tensor, - targets: tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | None, + targets: Targets | None, weight: torch.Tensor, grad_output: float, kwargs: dict, @@ -285,15 +277,34 @@ def _logits_cross_entropy_forward_backward_split( logit_input_grad = torch.empty_like(logit_input) else: logit_input_grad = None + + # Extract target tensors for splitting (keep same order as original tuple) + target_tensors = [ + targets.lm_target, + targets.dpo_target, + targets.reference_model_logits, + targets.loss_mask, + ] split_size = div( - get_unique(target.size(0) for target in targets if target is not None), + get_unique(target.size(0) for target in target_tensors if target is not None), self._config.cross_entropy_splits, ) tensors_split = [ [None] * self._config.cross_entropy_splits if tensor is None else tensor.split(split_size) - for tensor in [logit_input, *targets, logit_input_grad] + for tensor in [logit_input, 
*target_tensors, logit_input_grad] ] - for logit_input_, *targets_, logit_input_grad_ in zip(*tensors_split, strict=True): + for logit_input_, lm_target_, dpo_target_, reference_model_logits_, loss_mask_, logit_input_grad_ in zip( + *tensors_split, strict=True + ): + targets_ = Targets( + lm_target=lm_target_, + dpo_target=dpo_target_, + reference_model_logits=reference_model_logits_, + loss_mask=loss_mask_, + chosen_spans=targets.chosen_spans, + rejected_spans=targets.rejected_spans, + dpo_reference_model_logits=targets.dpo_reference_model_logits, + ) loss_, grad_ = self._logits_loss_forward_backward( logit_input_, targets_, @@ -319,7 +330,7 @@ def _logits_cross_entropy_forward_backward_split( def _logits_loss_forward_backward( self, input_: torch.Tensor, - targets: tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None], + targets: Targets | None, weight: torch.Tensor, grad_output: float, kwargs: dict, @@ -334,6 +345,7 @@ def _logits_loss_forward_backward( sequence_parallel=self._sequence_parallel and self._vocab_parallel, ) + # TODO: also move to lm_head_losses? if self._config.logit_z_loss > 0.0: logits = z_loss( logits, @@ -359,175 +371,48 @@ def _logits_loss_forward_backward( if targets is None: return logits * self._config.logits_scale_factor, None - dpo_target, lm_target, distillation_target, loss_mask = targets - if dpo_target is not None: - dpo_loss, dpo_grad = compute_dpo_loss( + total_loss, grad = None, None + for loss_name, loss_config in self._config.losses.items(): + if loss_config.weight_scalor == 0.0 and not loss_config.log_it: + continue + # losses are returned unscaled but the grads are already scaled + # we log unscaled losses seperately and the scaled total loss + loss_unscaled_, grad_ = loss_config.compute_loss( logits, - dpo_target, - kwargs.get(f"{self._config.dpo_reference_model}_logits"), - kwargs[LanguageModelKwargs.chosen_spans], - kwargs[LanguageModelKwargs.rejected_spans], - self._config.dpo_beta, - grad_output * self._loss_coefficient, - ) - else: - dpo_loss, dpo_grad = None, None - - if lm_target is not None and self._compute_lm_loss: - lm_loss, lm_grad = cross_entropy_forward_backward( - logits.flatten(0, -2), - lm_target, - None, + targets, + grad_output=( + grad_output * self._loss_coefficient * loss_config.weight_scalor + if grad_output is not None + else None + ), group=group, - grad_output=grad_output * self._loss_coefficient * self._config.language_model_loss_factor, - implementation=self._cross_entropy_impl, logits_scale_factor=self._config.logits_scale_factor, - target_format=TargetFormat.labels, + vocab_parallel=self._vocab_parallel, ) - else: - lm_loss, lm_grad = None, None - - distillation_rkl_grad, distillation_kl_grad, distillation_ce_grad = None, None, None - distillation_rkl_loss, distillation_kl_loss, distillation_ce_loss = None, None, None - - if distillation_target is not None and self._compute_distillation_loss: - if self._compute_rkl_loss: - distillation_rkl_loss, distillation_rkl_grad = reverse_kl_forward_backward( - logits.flatten(0, -2), - distillation_target, - loss_mask, - grad_output=grad_output * self._loss_coefficient * self._config.reverse_kl_loss_factor, - group=group, - logits_scale_factor=self._config.logits_scale_factor, - teacher_softmax_temperature=self._config.teacher_softmax_temperature, - target_format=( - TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits - ), - sequence_parallel_logits=self._sequence_parallel_logits, - ) + loss_ = loss_unscaled_ * 
loss_config.weight_scalor * self._loss_coefficient - if self._compute_kl_loss: - distillation_kl_loss, distillation_kl_grad = forward_kl_forward_backward( - logits.flatten(0, -2), - distillation_target, - loss_mask, - grad_output=grad_output * self._loss_coefficient * self._config.forward_kl_loss_factor, - group=group, - logits_scale_factor=self._config.logits_scale_factor, - teacher_softmax_temperature=self._config.teacher_softmax_temperature, - target_format=( - TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits - ), - sequence_parallel_logits=self._sequence_parallel_logits, - ) + if losses is not None and loss_config.log_it: + losses[self._formatted_loss_names[loss_name]].append(loss_unscaled_.detach()) - if self._compute_dist_ce_loss: - distillation_ce_loss, distillation_ce_grad = cross_entropy_forward_backward( - logits.flatten(0, -2), - distillation_target, - loss_mask, - group=group, - grad_output=grad_output * self._loss_coefficient * self._config.distillation_ce_loss_factor, - implementation=self._cross_entropy_impl, - logits_scale_factor=self._config.logits_scale_factor, - target_format=TargetFormat.logits, - ) + if total_loss is None: + total_loss = loss_ else: - raise ValueError( - f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" - ) - - # TODO: de-allocate earlier. - del logits - loss, grad = self._post_process_loss_and_grad( - dpo_loss, - dpo_grad, - lm_loss, - lm_grad, - distillation_rkl_loss, - distillation_rkl_grad, - distillation_kl_loss, - distillation_kl_grad, - distillation_ce_loss, - distillation_ce_grad, - losses, - kwargs, - ) - - return loss, output_parallel_linear_backward(grad, context) if self.training else None - - def _post_process_loss_and_grad( - self, - dpo_loss: torch.Tensor | None, - dpo_grad: torch.Tensor | None, - lm_loss: torch.Tensor | None, - lm_grad: torch.Tensor | None, - distillation_rkl_loss: torch.Tensor | None, - distillation_rkl_grad: torch.Tensor | None, - distillation_kl_loss: torch.Tensor | None, - distillation_kl_grad: torch.Tensor | None, - distillation_ce_loss: torch.Tensor | None, - distillation_ce_grad: torch.Tensor | None, - losses: dict | None, - kwargs, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - If loss is provided (i.e. not None) it will be logged in scaled and unscaled version. The total loss is also logged. - - Arguments: - - Losses: unscaled losses from different components (DPO, LM CE, Distillation) - - Grads: gradients of the losses w.r.t. logits from different components, already scaled by loss factors. - """ - # Extremely explicit but easier to follow. - # TODO: simplify / shrten / make seperate dataclass? 
- ############ - if dpo_loss is not None: - if self.training and losses is not None: - losses[self._dpo_loss_name].append(dpo_loss.detach()) - else: - Assert.is_(dpo_grad, None) + total_loss = total_loss + loss_ - if lm_loss is not None: - if self.training and losses is not None: - losses[self._lm_loss_name].append(lm_loss.detach()) - lm_loss = lm_loss * self._config.language_model_loss_factor - else: - Assert.is_(lm_grad, None) - - if distillation_rkl_loss is not None: - distillation_rkl_loss = distillation_rkl_loss - if self.training and losses is not None: - losses[self._distillation_rkl_loss_name].append(distillation_rkl_loss.detach()) - distillation_rkl_loss = distillation_rkl_loss * self._config.distillation_loss_factor - else: - Assert.is_(distillation_rkl_grad, None) - if distillation_kl_loss is not None: - distillation_kl_loss = distillation_kl_loss - if self.training and losses is not None: - losses[self._distillation_kl_loss_name].append(distillation_kl_loss.detach()) - distillation_kl_loss = distillation_kl_loss * self._config.distillation_loss_factor - else: - Assert.is_(distillation_kl_grad, None) - if distillation_ce_loss is not None: - distillation_ce_loss = distillation_ce_loss - if self.training and losses is not None: - losses[self._distillation_ce_loss_name].append(distillation_ce_loss.detach()) - distillation_ce_loss = distillation_ce_loss * self._config.distillation_loss_factor - else: - Assert.is_(distillation_ce_grad, None) + if grad_ is not None: + if grad is None: + grad = grad_ + else: + grad = grad + grad_ - ############ - # TODO: Accumulate grads in-place to reduce memory and compute overhead. - grad = _add_tensors(dpo_grad, lm_grad, distillation_rkl_grad, distillation_kl_grad, distillation_ce_grad) - total_loss = _add_tensors(dpo_loss, lm_loss, distillation_rkl_loss, distillation_kl_loss, distillation_ce_loss) if losses is not None and total_loss is not None: - losses[self._total_loss_name].append(total_loss.detach()) + losses[self._total_head_loss_name].append(total_loss.detach()) - return total_loss, grad + return total_loss, output_parallel_linear_backward(grad, context) if self.training else None @functools.cached_property - def _total_loss_name(self) -> str: + def _total_head_loss_name(self) -> str: """ Combined total scaled loss used for training. """ @@ -536,16 +421,6 @@ def _total_loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name - @functools.cached_property - def _lm_loss_name(self) -> str: - """ - Unscaled language model cross-entropy loss. 
- """ - name = "lm_loss_unscaled" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - @functools.cached_property def _z_loss_name(self) -> str: name = "z_loss" @@ -553,81 +428,18 @@ def _z_loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name - @functools.cached_property - def _dpo_loss_name(self) -> str: - name = "dpo_loss" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - - @functools.cached_property - def _distillation_kl_loss_name(self) -> str: - name = "distillation_kl_loss_unscaled" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - - @functools.cached_property - def _distillation_rkl_loss_name(self) -> str: - name = "distillation_rkl_loss_unscaled" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - - @functools.cached_property - def _distillation_ce_loss_name(self) -> str: - name = "distillation_ce_loss_unscaled" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: loss_defs = [ - LossDef(name=self._total_loss_name, formatted_name=_format_name(self._total_loss_name), count=count) - ] - if self._compute_lm_loss: - loss_defs.append( - LossDef( - name=self._lm_loss_name_unscaled, - formatted_name=_format_name(self._lm_loss_name_unscaled), - count=count, - ) + LossDef( + name=self._total_head_loss_name, formatted_name=_format_name(self._total_head_loss_name), count=count ) - if self._config.logit_z_loss: - loss_defs.append( - LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) - ) - if self._compute_dpo_loss: - loss_defs.append( - LossDef(name=self._dpo_loss_name, formatted_name=_format_name(self._dpo_loss_name), count=count) - ) - - if self._compute_distillation_loss: - # unscaled distillation loss for comparison purposes - if self._compute_kl_loss: - loss_defs.append( - LossDef( - name=self._distillation_kl_loss_name, - formatted_name=_format_name(self._distillation_kl_loss_name), - count=count, - ) - ) - if self._compute_rkl_loss: - loss_defs.append( - LossDef( - name=self._distillation_rkl_loss_name, - formatted_name=_format_name(self._distillation_rkl_loss_name), - count=count, - ) - ) - if self._compute_dist_ce_loss: - loss_defs.append( - LossDef( - name=self._distillation_ce_loss_name, - formatted_name=_format_name(self._distillation_ce_loss_name), - count=count, - ) + ] + for loss_name, loss_config in self._config.losses.items(): + if loss_config.log_it: + loss_def: LossDef = loss_config.get_loss_def( + name=loss_name, count=count, prediction_distance=self._prediction_distance ) + loss_defs.append(loss_def) return loss_defs @@ -635,17 +447,3 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: def heads(self): # For compatibility with MTP. 
return [self] - - -def _format_name(name: str) -> str: - return name.replace("_", " ") - - -def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor: - tensors = [tensor for tensor in tensors if tensor is not None] - if len(tensors) > 1: - return sum(tensors) - elif len(tensors) == 1: - return tensors[0] - else: - raise RuntimeError("No tensors to add.") diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py new file mode 100644 index 00000000..cc8e5ebc --- /dev/null +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -0,0 +1,280 @@ +import abc +import dataclasses +import logging +import typing + +import torch + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.core.distributed import ProcessGroup +from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + +# +# CE loss on lm_targets for standard LM training. Here targets are already masked. +# CE loss for distillation: cross entropuy that uses reference_model_logits as soft targets, not implemented, TODO. +# Forward KL divergence loss on reference_model_logits for distillation (mode-covering). +# Reverse KL divergence loss on reference_model_logits for distillation (mode-seeking). +# DPO loss for alignment using chosen and rejected spans. +# + + +def _format_name(name: str) -> str: + return name.replace("_", " ") + + +@dataclasses.dataclass +class Targets: + lm_target: torch.Tensor | None = None + dpo_target: torch.Tensor | None = None + loss_mask: torch.Tensor | None = None + chosen_spans: list[list[tuple[int, int]]] | None = None + rejected_spans: list[list[tuple[int, int]]] | None = None + reference_model_logits: torch.Tensor | None = None + dpo_reference_model_logits: torch.Tensor | None = None + + def has_any_target(self) -> bool: + return any(getattr(self, field.name) is not None for field in dataclasses.fields(self)) + + +@config_class(registry=True) +class LossConfig(Config): + """ + Losses canm register themselves + using @config_class(dynamic_type={LossConfig: "loss_type_name"}) + """ + + _name: typing.ClassVar[str] + _abstract: typing.ClassVar[bool] = True + + weight_scalor: float = Field( + default=1.0, + hint=FieldHint.core, + desc="Weight for this loss in the total loss computation.", + valid=check_field(Assert.geq, 0.0), + ) + + log_it: bool = Field( + default=True, + hint=FieldHint.optional, + desc="Whether to log this loss.", + ) + + @abc.abstractmethod + def compute_loss( + self, + logits: torch.Tensor, + target: Targets, + grad_output: float | None = None, + group: ProcessGroup | None = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + pass + + def get_loss_def(self, name: str, count: int = 1, prediction_distance: int | None = None) -> LossDef: + name = self.get_formatted_name(name, prediction_distance) + return LossDef( + name=name, + formatted_name=_format_name(name), + count=count, + dtype=DataType.float32, + ) + + def _validate(self): + Assert.geq(self.weight_scalor, 0.0) + if self.weight_scalor > 0.0: + with self._set_implicit_default(): + if "log_it" not in self._explicit_fields: + self.log_it = True + super()._validate() + + def get_formatted_name(self, 
name=None, prediction_distance: int | None = None) -> str: + name = f"{self._name}({name})" + if prediction_distance is not None: + name = f"{name}_{prediction_distance}" + return name + + +@config_class(dynamic_type={LossConfig: "cross_entropy_lm_loss"}) +class CrossEntropyLMLossConfig(LossConfig): + _name: typing.ClassVar[str] = "CE" + _abstract: typing.ClassVar[bool] = False + + implementation: CrossEntropyImpl = Field( + default=CrossEntropyImpl.auto, + desc="Implementation for the cross-entropy computation.", + hint=FieldHint.performance, + ) + + teacher_softmax_temperature: float = Field( + default=1.0, + hint=FieldHint.optional, + desc="Temperature for teacher softmax (used in distillation losses).", + valid=check_field(Assert.gt, 0.0), + ) + + def compute_loss( + self, + logits: torch.Tensor, + targets: Targets, + grad_output: float | None = None, + group: ProcessGroup | None = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + from fast_llm.functional.cross_entropy import cross_entropy_forward_backward + + target = targets.lm_target + if target is None: + raise ValueError("CrossEntropyLoss requires lm_target to be set in Targets") + implementation = self.implementation + if implementation == CrossEntropyImpl.auto: + if vocab_parallel: + implementation = CrossEntropyImpl.fused + elif TritonConfig.TRITON_ENABLED: + implementation = CrossEntropyImpl.triton + else: + implementation = CrossEntropyImpl.fused + + return cross_entropy_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=None, # Labels are already masked + grad_output=grad_output, + group=group, + implementation=implementation, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=self.teacher_softmax_temperature, + target_format=TargetFormat.labels, + **kwargs, + ) + + +@config_class(dynamic_type={LossConfig: "fkl_dist"}) +class ForwardKLLossConfig(LossConfig): + """Forward KL divergence KL(p||q) for distillation (mode-covering).""" + + _name: typing.ClassVar[str] = "FwdKL" + _abstract: typing.ClassVar[bool] = False + + teacher_softmax_temperature: float = Field( + default=1.0, + hint=FieldHint.optional, + desc="Temperature for teacher softmax.", + valid=check_field(Assert.gt, 0.0), + ) + + def compute_loss( + self, + logits: torch.Tensor, + targets: Targets, + grad_output: float | None = None, + group: ProcessGroup | None = None, + logits_scale_factor: float | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + from fast_llm.functional.cross_entropy import forward_kl_forward_backward + + target = targets.reference_model_logits + if target is None: + raise ValueError("ForwardKLLoss requires distillation_target to be set in Targets") + + return forward_kl_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=targets.loss_mask, + grad_output=grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=self.teacher_softmax_temperature, + target_format=TargetFormat.logits, + **kwargs, + ) + + +@config_class(dynamic_type={LossConfig: "revkl_dist"}) +class ReverseKLLossConfig(LossConfig): + """Reverse KL divergence KL(q||p) for distillation (mode-seeking).""" + + _name: typing.ClassVar[str] = "RevKL" + _abstract: typing.ClassVar[bool] = False + + teacher_softmax_temperature: float = Field( + default=1.0, + hint=FieldHint.optional, + desc="Temperature for teacher softmax.", + valid=check_field(Assert.gt, 
0.0), + ) + + def compute_loss( + self, + logits: torch.Tensor, + targets: Targets, + grad_output: float | None = None, + group: ProcessGroup | None = None, + logits_scale_factor: float | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + from fast_llm.functional.cross_entropy import reverse_kl_forward_backward + + # Use distillation_target for KL losses + target = targets.reference_model_logits + if target is None: + raise ValueError("ReverseKLLoss requires distillation_target to be set in Targets") + + return reverse_kl_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=targets.loss_mask, + grad_output=grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=self.teacher_softmax_temperature, + target_format=TargetFormat.logits, + **kwargs, + ) + + +@config_class(dynamic_type={LossConfig: "dpo"}) +class DPOLossConfig(LossConfig): + """Direct Preference Optimization (DPO) loss for alignment.""" + + _name: typing.ClassVar[str] = "DPO" + _abstract: typing.ClassVar[bool] = False + + beta: float = Field( + default=1.0, + hint=FieldHint.core, + desc="Beta parameter for DPO loss (controls strength of preference optimization).", + valid=check_field(Assert.gt, 0.0), + ) + + def compute_loss( + self, + logits: torch.Tensor, + targets: Targets, + grad_output: float | None = None, + group: ProcessGroup | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + from fast_llm.functional.dpo import compute_dpo_loss + + return compute_dpo_loss( + logits=logits, + targets=targets.dpo_target, + reference_model_logits=targets.dpo_reference_model_logits, + chosen_spans=targets.chosen_spans, + rejected_spans=targets.rejected_spans, + beta=self.beta, + grad_output=grad_output, + ) From 097baeb4c2396575066f96ced831771e0054ea76 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 22 Dec 2025 14:24:57 +0000 Subject: [PATCH 10/21] wip --- fast_llm/functional/config.py | 6 - fast_llm/layers/language_model/config.py | 4 +- fast_llm/layers/language_model/head.py | 8 +- .../layers/language_model/lm_head_losses.py | 6 +- tests/layers/test_lm_head.py | 188 +++++++++--------- tests/utils/model_configs.py | 8 +- 6 files changed, 108 insertions(+), 112 deletions(-) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 20ed99fd..511c2d9f 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -100,12 +100,6 @@ class CrossEntropyImpl(str, enum.Enum): triton = "triton" -class DistillationLossImpl(str, enum.Enum): - reverse_kl = "reverse_kl" - forward_kl = "forward_kl" - cross_entropy = "cross_entropy" - - class TargetFormat(enum.StrEnum): labels = "labels" logits = "logits" diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 6fc92eaa..786d312d 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -212,9 +212,7 @@ def _validate(self) -> None: with self._set_implicit_default(): if not self.losses: self.losses = { - "lm_loss": LossConfig._from_dict( - {"type": "cross_entropy_lm_loss", "weight_scalor": 1.0, "log_it": True} - ) + "lm_loss": LossConfig._from_dict({"type": "cross_entropy_lm_loss", "factor": 1.0, "log_it": True}) } for loss_config in self.losses.values(): diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index f23bb6f1..c8c3be79 100644 --- a/fast_llm/layers/language_model/head.py +++ 
b/fast_llm/layers/language_model/head.py @@ -374,7 +374,7 @@ def _logits_loss_forward_backward( total_loss, grad = None, None for loss_name, loss_config in self._config.losses.items(): - if loss_config.weight_scalor == 0.0 and not loss_config.log_it: + if loss_config.factor == 0.0 and not loss_config.log_it: continue # losses are returned unscaled but the grads are already scaled # we log unscaled losses seperately and the scaled total loss @@ -382,15 +382,13 @@ def _logits_loss_forward_backward( logits, targets, grad_output=( - grad_output * self._loss_coefficient * loss_config.weight_scalor - if grad_output is not None - else None + grad_output * self._loss_coefficient * loss_config.factor if grad_output is not None else None ), group=group, logits_scale_factor=self._config.logits_scale_factor, vocab_parallel=self._vocab_parallel, ) - loss_ = loss_unscaled_ * loss_config.weight_scalor * self._loss_coefficient + loss_ = loss_unscaled_ * loss_config.factor * self._loss_coefficient if losses is not None and loss_config.log_it: losses[self._formatted_loss_names[loss_name]].append(loss_unscaled_.detach()) diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index cc8e5ebc..a231efa5 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -54,7 +54,7 @@ class LossConfig(Config): _name: typing.ClassVar[str] _abstract: typing.ClassVar[bool] = True - weight_scalor: float = Field( + factor: float = Field( default=1.0, hint=FieldHint.core, desc="Weight for this loss in the total loss computation.", @@ -90,8 +90,8 @@ def get_loss_def(self, name: str, count: int = 1, prediction_distance: int | Non ) def _validate(self): - Assert.geq(self.weight_scalor, 0.0) - if self.weight_scalor > 0.0: + Assert.geq(self.factor, 0.0) + if self.factor > 0.0: with self._set_implicit_default(): if "log_it" not in self._explicit_fields: self.log_it = True diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index c6d806db..917bb7ef 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -5,7 +5,7 @@ from fast_llm.config import UpdateType from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl +from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs from fast_llm.layers.language_model.head import LanguageModelHead @@ -119,99 +119,99 @@ def _lm_head( ({"tied_embedding_weight": True}, {}, False, 1), ({}, {}, False, 2), ({}, {}, True, 1), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.cross_entropy, - } - }, - {}, - False, - 1, - ), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.reverse_kl, - } - }, - {}, - False, - 1, - ), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.cross_entropy, - "language_model_loss_factor": 1.0, - } - }, - {}, - True, - 1, - ), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.reverse_kl, - } - }, - {}, - True, - 1, - ), - pytest.param( - { - "head": { - "distillation_model": "distillation", - 
"language_model_loss_factor": 0.0, - "track_language_model_loss": True, - "distillation_loss_factor": 1.0, - } - }, - {}, - False, - 1, - id="track_lm_zero_factor", - ), - pytest.param( - { - "head": { - "distillation_model": "distillation", - "language_model_loss_factor": 0.0, - "distillation_loss_factor": 0.0, - "track_language_model_loss": True, - "track_distillation_loss": True, - } - }, - {}, - False, - 1, - id="track_both_zero_factors", - ), - pytest.param( - { - "head": { - "distillation_model": "distillation", - "language_model_loss_factor": 0.0, - "distillation_loss_factor": 0.0, - "track_language_model_loss": False, - "track_distillation_loss": False, - } - }, - {}, - False, - 1, - marks=pytest.mark.xfail( - reason="No losses computed when all factors=0 and tracking=False, raises RuntimeError in _add_tensors", - strict=True, - ), - id="zero_factors_no_tracking", - ), + # ( + # { + # "head": { + # "distillation_model": "distillation", + # "distillation_loss_implementation": DistillationLossImpl.cross_entropy, + # } + # }, + # {}, + # False, + # 1, + # ), + # ( + # { + # "head": { + # "distillation_model": "distillation", + # "distillation_loss_implementation": DistillationLossImpl.reverse_kl, + # } + # }, + # {}, + # False, + # 1, + # ), + # ( + # { + # "head": { + # "distillation_model": "distillation", + # "distillation_loss_implementation": DistillationLossImpl.cross_entropy, + # "language_model_loss_factor": 1.0, + # } + # }, + # {}, + # True, + # 1, + # ), + # ( + # { + # "head": { + # "distillation_model": "distillation", + # "distillation_loss_implementation": DistillationLossImpl.reverse_kl, + # } + # }, + # {}, + # True, + # 1, + # ), + # pytest.param( + # { + # "head": { + # "distillation_model": "distillation", + # "language_model_loss_factor": 0.0, + # "track_language_model_loss": True, + # "distillation_loss_factor": 1.0, + # } + # }, + # {}, + # False, + # 1, + # id="track_lm_zero_factor", + # ), + # pytest.param( + # { + # "head": { + # "distillation_model": "distillation", + # "language_model_loss_factor": 0.0, + # "distillation_loss_factor": 0.0, + # "track_language_model_loss": True, + # "track_distillation_loss": True, + # } + # }, + # {}, + # False, + # 1, + # id="track_both_zero_factors", + # ), + # pytest.param( + # { + # "head": { + # "distillation_model": "distillation", + # "language_model_loss_factor": 0.0, + # "distillation_loss_factor": 0.0, + # "track_language_model_loss": False, + # "track_distillation_loss": False, + # } + # }, + # {}, + # False, + # 1, + # marks=pytest.mark.xfail( + # reason="No losses computed when all factors=0 and tracking=False, raises RuntimeError in _add_tensors", + # strict=True, + # ), + # id="zero_factors_no_tracking", + # ), ), ) def test_lm_head( diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 6156cb70..f4e3ecea 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -552,6 +552,12 @@ def _update_and_add_testing_config( "mistral_distill_logits", updates={ ("model", "base_model", "head", "distillation_model"): "teacher", + ("model", "base_model", "head", "losses"): { + "distillation_loss": { + "type": "revkl_dist", + "factor": 1.0, + }, + }, ("batch", "use_loss_masking_spans"): True, ("reference_models"): { "teacher": { @@ -599,7 +605,7 @@ def _update_and_add_testing_config( "mistral_distill_logits", "mistral_distill_activations", updates={ - ("model", "base_model", "head", "distillation_loss_factor"): 0.001, + ("model", "base_model", "head", "losses", 
"distillation_loss", "factor"): 0.001, ("model", "base_model", "decoder", "block", "distillation_model"): "teacher", ("model", "base_model", "decoder", "block", "activation_distillation_factor"): 0.1, ("reference_models"): { From d773d986d54ed3cc1729d9bd8992af116c8f20de Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 22 Dec 2025 16:47:11 +0000 Subject: [PATCH 11/21] tests --- fast_llm/layers/language_model/head.py | 4 + tests/layers/test_lm_head.py | 340 +++++++++++++++---------- 2 files changed, 214 insertions(+), 130 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index c8c3be79..c47a87de 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -432,6 +432,10 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: name=self._total_head_loss_name, formatted_name=_format_name(self._total_head_loss_name), count=count ) ] + if self._config.logit_z_loss > 0.0: + loss_defs.append( + LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) + ) for loss_name, loss_config in self._config.losses.items(): if loss_config.log_it: loss_def: LossDef = loss_config.get_loss_def( diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 917bb7ef..5835b667 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -9,6 +9,7 @@ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs from fast_llm.layers.language_model.head import LanguageModelHead +from fast_llm.layers.language_model.lm_head_losses import LossConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -43,6 +44,20 @@ def _reverse_kl_loss( return loss +def _kl_loss( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + teacher_softmax_temperature: float = 1.0, +): + return _reverse_kl_loss( + target, + logits, + loss_mask, + teacher_softmax_temperature, + ) + + def _lm_head( input_: torch.Tensor, target: torch.Tensor, @@ -54,9 +69,7 @@ def _lm_head( grad_output: float = 1.0, logit_scale_factor: float = 1.0, logit_z_loss=0.0, - distillation_loss_implementation: DistillationLossImpl = DistillationLossImpl.cross_entropy, - language_model_loss_factor: float = 1.0, - distillation_loss_factor: float = 1.0, + losses: dict[str, LossConfig], ): hidden = torch.rms_norm( input_.to(rms_weight.dtype), @@ -66,36 +79,34 @@ def _lm_head( ) logits = torch.nn.functional.linear(hidden, logit_weight).float() - if distillation_loss_implementation == DistillationLossImpl.reverse_kl: - Assert.eq(logits.shape, target.shape) - loss = _reverse_kl_loss( - (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask - ) - # Apply distillation_loss_factor to grad_output for backward - loss.backward(torch.full_like(loss, grad_output * distillation_loss_factor)) - # Return scaled loss - return loss * distillation_loss_factor, None + if "dist_loss" in losses: + if losses["dist_loss"].type == "revkl_dist": + Assert.eq(logits.shape, target.shape) + loss = _reverse_kl_loss( + (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask + ) + # Apply distillation_loss_factor to grad_output for backward + loss.backward(torch.full_like(loss, grad_output 
* losses["dist_loss"].factor)) + # Return scaled loss + return loss * losses["dist_loss"].factor, None + elif losses["dist_loss"].type == "fkl_dist": + Assert.eq(logits.shape, target.shape) + loss = _kl_loss( + (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask + ) + # Apply distillation_loss_factor to grad_output for backward + loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].factor)) + # Return scaled loss + return loss * losses["dist_loss"].factor, None if logit_scale_factor != 1.0: logits *= logit_scale_factor z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) if logit_z_loss > 0 else None - if target.ndim == logits.ndim: - # Distillation loss (cross-entropy with soft targets) - loss = torch.nn.functional.cross_entropy( - logits.flatten(0, -2), target.float().softmax(-1).flatten(0, -2), reduction="none" - ) - if loss_mask is not None: - loss = loss * loss_mask.flatten() - loss = loss.mean() - # Apply distillation_loss_factor - loss.backward(torch.full_like(loss, grad_output * distillation_loss_factor)) - return loss * distillation_loss_factor, z_loss - else: - # Language model loss (cross-entropy with hard labels) - loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) - # Apply language_model_loss_factor - loss.backward(torch.full_like(loss, grad_output * language_model_loss_factor)) - return loss * language_model_loss_factor, z_loss + # Language model loss (cross-entropy with hard labels) + loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) + # Apply language_model_loss_factor + loss.backward(torch.full_like(loss, grad_output * losses["lm_loss"].factor)) + return loss * losses["lm_loss"].factor, z_loss SEQUENCE_LENGTH = 200 @@ -119,99 +130,169 @@ def _lm_head( ({"tied_embedding_weight": True}, {}, False, 1), ({}, {}, False, 2), ({}, {}, True, 1), + # Skip CE distillation for now - not yet implemented in new losses system # ( # { # "head": { # "distillation_model": "distillation", - # "distillation_loss_implementation": DistillationLossImpl.cross_entropy, - # } - # }, - # {}, - # False, - # 1, - # ), - # ( - # { - # "head": { - # "distillation_model": "distillation", - # "distillation_loss_implementation": DistillationLossImpl.reverse_kl, + # "losses": { + # "lm_loss": { + # "type": "cross_entropy_lm_loss", + # "weight_scalor": 0.0, + # "log_it": False, + # }, + # "dist_loss": { + # "type": "cross_entropy_dist", # TODO: Not implemented yet + # "weight_scalor": 1.0, + # "log_it": True, + # } + # } # } # }, # {}, # False, # 1, # ), + ( + { + "head": { + "distillation_model": "distillation", + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "factor": 0.0, + "log_it": False, + }, + "dist_loss": { + "type": "revkl_dist", + "factor": 1.0, + "log_it": True, + }, + }, + } + }, + {}, + False, + 1, + ), + # Skip - CE distillation not implemented # ( # { # "head": { # "distillation_model": "distillation", - # "distillation_loss_implementation": DistillationLossImpl.cross_entropy, - # "language_model_loss_factor": 1.0, - # } - # }, - # {}, - # True, - # 1, - # ), - # ( - # { - # "head": { - # "distillation_model": "distillation", - # "distillation_loss_implementation": DistillationLossImpl.reverse_kl, + # "losses": { + # "lm_loss": { + # "type": "cross_entropy_lm_loss", + # "weight_scalor": 1.0, + # "log_it": True, + # }, + # "dist_loss": { + # "type": "cross_entropy_dist", # TODO + # "weight_scalor": 1.0, + # "log_it": True, + # } + # } # } # 
}, # {}, # True, # 1, # ), - # pytest.param( - # { - # "head": { - # "distillation_model": "distillation", - # "language_model_loss_factor": 0.0, - # "track_language_model_loss": True, - # "distillation_loss_factor": 1.0, - # } - # }, - # {}, - # False, - # 1, - # id="track_lm_zero_factor", - # ), - # pytest.param( - # { - # "head": { - # "distillation_model": "distillation", - # "language_model_loss_factor": 0.0, - # "distillation_loss_factor": 0.0, - # "track_language_model_loss": True, - # "track_distillation_loss": True, - # } - # }, - # {}, - # False, - # 1, - # id="track_both_zero_factors", - # ), - # pytest.param( - # { - # "head": { - # "distillation_model": "distillation", - # "language_model_loss_factor": 0.0, - # "distillation_loss_factor": 0.0, - # "track_language_model_loss": False, - # "track_distillation_loss": False, - # } - # }, - # {}, - # False, - # 1, - # marks=pytest.mark.xfail( - # reason="No losses computed when all factors=0 and tracking=False, raises RuntimeError in _add_tensors", - # strict=True, - # ), - # id="zero_factors_no_tracking", - # ), + ( + { + "head": { + "distillation_model": "distillation", + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "factor": 0.0, + "log_it": False, + }, + "dist_loss": { + "type": "revkl_dist", + "factor": 1.0, + "log_it": True, + }, + }, + } + }, + {}, + True, + 1, + ), + pytest.param( + { + "head": { + "distillation_model": "distillation", + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "factor": 0.0, + "log_it": True, # tracking even with zero weight + }, + "dist_loss": { + "type": "revkl_dist", + "factor": 1.0, + "log_it": True, + }, + }, + } + }, + {}, + False, + 1, + id="track_lm_zero_factor", + ), + pytest.param( + { + "head": { + "distillation_model": "distillation", + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "factor": 0.0, + "log_it": True, # tracking with zero weight + }, + "dist_loss": { + "type": "revkl_dist", + "factor": 0.0, + "log_it": True, # tracking with zero weight + }, + }, + } + }, + {}, + False, + 1, + id="track_both_zero_factors", + ), + pytest.param( + { + "head": { + "distillation_model": "distillation", + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "factor": 0.0, + "log_it": False, # not tracking with zero weight + }, + "dist_loss": { + "type": "revkl_dist", + "factor": 0.0, + "log_it": False, # not tracking with zero weight + }, + }, + } + }, + {}, + False, + 1, + marks=pytest.mark.xfail( + reason="No losses computed when all factors=0 and log_it=False", + strict=True, + ), + id="zero_factors_no_tracking", + ), ), ) def test_lm_head( @@ -222,8 +303,15 @@ def test_lm_head( prediction_heads: int, ): head_config = { - "cross_entropy_implementation": cross_entropy_impl, "normalization": {"type": "rms_norm"}, + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "implementation": cross_entropy_impl, + "factor": 1.0, + "log_it": True, + } + }, } config = GPTBaseModelConfig.from_dict( { @@ -280,19 +368,19 @@ def test_lm_head( AttentionKwargs.sequence_first: sequence_first, AttentionKwargs.grad_output: 1.0, } - if head_config.distillation_model is None: - target = torch.randint( - 0, - VOCAB_SIZE, - label_shape, - dtype=torch.int64, - device=distributed.device, - ) - if loss_mask is not None: - target *= loss_mask + # always set lm targets + target = torch.randint( + 0, + VOCAB_SIZE, + label_shape, + dtype=torch.int64, + device=distributed.device, + ) + if loss_mask is not None: + target *= loss_mask - 
kwargs[LanguageModelKwargs.labels] = target - else: + kwargs[LanguageModelKwargs.labels] = target + if head_config.distillation_model is not None: assert config.head.max_prediction_distance == 1 target = torch.randn( input_.shape[:-1] + (VOCAB_SIZE,), @@ -349,11 +437,7 @@ def test_lm_head( logit_weight=ref_logit_weight, logit_scale_factor=head_config.logits_scale_factor, logit_z_loss=head_config.logit_z_loss, - distillation_loss_implementation=head_config.distillation_loss_implementation, - language_model_loss_factor=( - head_config.language_model_loss_factor if head_config.language_model_loss_factor is not None else 1.0 - ), - distillation_loss_factor=head_config.distillation_loss_factor, + losses=head_config.losses, ) # Prepare LM head inputs @@ -367,19 +451,15 @@ def test_lm_head( lm_head_loss_name = f"lm_head_loss_{prediction_distance}" if prediction_distance > 0 else "lm_head_loss" expected_loss_keys = {lm_head_loss_name} - if head._compute_lm_loss: - lm_loss_name_unscaled = ( - f"lm_loss_unscaled_{prediction_distance}" if prediction_distance > 0 else "lm_loss_unscaled" - ) - lm_loss_name = f"lm_loss_{prediction_distance}" if prediction_distance > 0 else "lm_loss" - expected_loss_keys.add(lm_loss_name_unscaled) - expected_loss_keys.add(lm_loss_name) + # Get expected loss names from the loss configs + for loss_name, loss_config in head._config.losses.items(): + if loss_config.log_it: + formatted_name = loss_config.get_formatted_name(loss_name, prediction_distance) + expected_loss_keys.add(formatted_name) + if ref_z_loss is not None: expected_loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") - if head._compute_distillation_loss: - expected_loss_keys.add("distillation_loss") - expected_loss_keys.add("distillation_loss_unscaled") Assert.eq( {loss_definition.name: loss_definition.count for loss_definition in head.get_loss_definitions()}, From 282925c5bcd6f3b2648aa1cfd4d40bed4058a739 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 22 Dec 2025 16:51:37 +0000 Subject: [PATCH 12/21] test --- tests/layers/test_lm_head.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 5835b667..6bdaf3f6 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -293,6 +293,32 @@ def _lm_head( ), id="zero_factors_no_tracking", ), + pytest.param( + { + "head": { + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "factor": 1.0, + "log_it": False, # not tracking with zero weight + }, + "dist_loss": { + "type": "revkl_dist", + "factor": 1.0, + "log_it": True, # not tracking with zero weight + }, + }, + } + }, + {}, + False, + 1, + marks=pytest.mark.xfail( + reason="Cannot track distillation loss without distillation model being set", + strict=True, + ), + id="track_distillation_without_model", + ), ), ) def test_lm_head( From 0f73ea23d62e43c41c45a9e755e9e3db38a3a5a3 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 22 Dec 2025 17:54:53 +0000 Subject: [PATCH 13/21] tests --- fast_llm/layers/language_model/config.py | 13 ++--- fast_llm/layers/language_model/head.py | 1 + .../layers/language_model/lm_head_losses.py | 47 +++++++++---------- tests/test_config.py | 1 + tests/utils/model_configs.py | 28 +++-------- 5 files changed, 35 insertions(+), 55 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 786d312d..411e98f4 100644 --- a/fast_llm/layers/language_model/config.py +++ 
b/fast_llm/layers/language_model/config.py @@ -209,17 +209,12 @@ def layer_class(self) -> "type[LanguageModelHead]": return LanguageModelHead def _validate(self) -> None: - with self._set_implicit_default(): - if not self.losses: - self.losses = { - "lm_loss": LossConfig._from_dict({"type": "cross_entropy_lm_loss", "factor": 1.0, "log_it": True}) - } - - for loss_config in self.losses.values(): - if "dist" in loss_config.type: - assert self.distillation_model is not None, "Distillation loss requires a distillation model." + for loss_config in self.losses.values(): + if "dist" in loss_config.type: + assert self.distillation_model is not None, "Distillation loss requires a distillation model." super()._validate() assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both + # Note: Default loss is handled at runtime in head.py if losses dict is empty @property def max_prediction_distance(self) -> int: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index c47a87de..e1f30332 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -100,6 +100,7 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) + assert self._config.losses, "At least one loss must be configured." self._formatted_loss_names = { loss_name: loss_config.get_formatted_name(loss_name, self._prediction_distance) for loss_name, loss_config in self._config.losses.items() diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index a231efa5..9fd94662 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -3,17 +3,16 @@ import logging import typing -import torch - from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.core.distributed import ProcessGroup from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - pass + import torch + + from fast_llm.core.distributed import ProcessGroup logger = logging.getLogger(__name__) @@ -32,13 +31,13 @@ def _format_name(name: str) -> str: @dataclasses.dataclass class Targets: - lm_target: torch.Tensor | None = None - dpo_target: torch.Tensor | None = None - loss_mask: torch.Tensor | None = None + lm_target: "torch.Tensor | None" = None + dpo_target: "torch.Tensor | None" = None + loss_mask: "torch.Tensor | None" = None chosen_spans: list[list[tuple[int, int]]] | None = None rejected_spans: list[list[tuple[int, int]]] | None = None - reference_model_logits: torch.Tensor | None = None - dpo_reference_model_logits: torch.Tensor | None = None + reference_model_logits: "torch.Tensor | None" = None + dpo_reference_model_logits: "torch.Tensor | None" = None def has_any_target(self) -> bool: return any(getattr(self, field.name) is not None for field in dataclasses.fields(self)) @@ -70,14 +69,14 @@ class LossConfig(Config): @abc.abstractmethod def compute_loss( self, - logits: torch.Tensor, + logits: "torch.Tensor", target: Targets, grad_output: float | None = None, - group: ProcessGroup | None = None, + group: "ProcessGroup" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> "tuple[torch.Tensor, torch.Tensor | 
None]": pass def get_loss_def(self, name: str, count: int = 1, prediction_distance: int | None = None) -> LossDef: @@ -124,14 +123,14 @@ class CrossEntropyLMLossConfig(LossConfig): def compute_loss( self, - logits: torch.Tensor, + logits: "torch.Tensor", targets: Targets, grad_output: float | None = None, - group: ProcessGroup | None = None, + group: "ProcessGroup" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.cross_entropy import cross_entropy_forward_backward target = targets.lm_target @@ -176,13 +175,13 @@ class ForwardKLLossConfig(LossConfig): def compute_loss( self, - logits: torch.Tensor, + logits: "torch.Tensor", targets: Targets, grad_output: float | None = None, - group: ProcessGroup | None = None, + group: "ProcessGroup" = None, logits_scale_factor: float | None = None, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.cross_entropy import forward_kl_forward_backward target = targets.reference_model_logits @@ -218,13 +217,13 @@ class ReverseKLLossConfig(LossConfig): def compute_loss( self, - logits: torch.Tensor, + logits: "torch.Tensor", targets: Targets, grad_output: float | None = None, - group: ProcessGroup | None = None, + group: "ProcessGroup" = None, logits_scale_factor: float | None = None, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.cross_entropy import reverse_kl_forward_backward # Use distillation_target for KL losses @@ -261,12 +260,12 @@ class DPOLossConfig(LossConfig): def compute_loss( self, - logits: torch.Tensor, + logits: "torch.Tensor", targets: Targets, grad_output: float | None = None, - group: ProcessGroup | None = None, + group: "ProcessGroup" = None, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.dpo import compute_dpo_loss return compute_dpo_loss( diff --git a/tests/test_config.py b/tests/test_config.py index 4020b6fb..8d6f3924 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -147,6 +147,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): "normalization": {"implementation": "triton"}, }, "num_blocks": 12, + "head": {}, }, "hidden_size": 512, "tied_embedding_weight": False, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index f4e3ecea..3cadb4e2 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -240,7 +240,12 @@ def _update_and_add_testing_config( }, "num_blocks": 2, }, - "head": {"output_weight": init_1}, + "head": { + "output_weight": init_1, + "losses": { + "lm_loss": {"type": "cross_entropy_lm_loss", "factor": 1.0, "log_it": True}, + }, + }, "hidden_size": 256, "tied_embedding_weight": True, }, @@ -580,27 +585,6 @@ def _update_and_add_testing_config( skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2"), ) -_update_and_add_testing_config( - "mistral_distill_logits", - "mistral_reverse_kl", - updates={ - ("model", "base_model", "head", "distillation_loss_implementation"): "reverse_kl", - }, - megatron_args=None, - checkpoint_format=MistralCheckpointFormat, - groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.normal, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.unimportant, - ModelTestingGroup.convert: 
ModelTestingGroupAction.unimportant, - ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, - ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.broken, # failing: fp16, tp2, stp2, stp2_ce4 - }, - compare_factor=2, - # Modes not supported with reference models - skip_tests=("sdp", "ms", "pp"), -) - _update_and_add_testing_config( "mistral_distill_logits", "mistral_distill_activations", From fa85c415abd4481baba7ac9b9e037854e72cea82 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 22 Dec 2025 22:27:28 +0000 Subject: [PATCH 14/21] wip --- fast_llm/functional/cross_entropy.py | 104 +++----------- fast_llm/layers/language_model/config.py | 4 +- fast_llm/layers/language_model/head.py | 13 +- .../layers/language_model/lm_head_losses.py | 30 ++-- tests/layers/test_lm_head.py | 132 +++--------------- tests/utils/model_configs.py | 4 +- 6 files changed, 55 insertions(+), 232 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index f534d8a7..06c85848 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -85,6 +85,7 @@ def _fused_cross_entropy_forward_backward( target_format: TargetFormat, group: ProcessGroup | None = None, teacher_softmax_temperature: float = 1.0, + return_target_entropy: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile. @@ -158,6 +159,16 @@ def _fused_cross_entropy_forward_backward( loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: all_reduce(loss, op=ReduceOp.AVG, group=group) + if return_target_entropy and target_format == TargetFormat.logits: + # Compute teacher entropy + teacher_log_prob = torch.log(target + 1e-20) + target_entropy = -(target * teacher_log_prob).sum(dim=-1) + if loss_mask is not None: + target_entropy = target_entropy * loss_mask.squeeze(-1) + target_entropy = target_entropy.mean() + if group is not None: + all_reduce(target_entropy, op=ReduceOp.SUM, group=group) + return loss, grad, target_entropy return loss, grad @@ -362,78 +373,6 @@ def reverse_kl_forward_backward( return distillation_loss, distillation_grad -@torch.compile -def _forward_kl_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - group: ProcessGroup | None = None, - logits_scale_factor: float = 1.0, - teacher_softmax_temperature: float = 1.0, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Forward KL: KL(p||q) where p=teacher, q=student. - This is reverse KL with roles swapped in the loss computation. - - Key insight: KL(p||q) = sum_i p_i * log(p_i/q_i) - = sum_i p_i * (log(p_i) - log(q_i)) - which is reverse KL with p and q swapped. - - However, we still need grad w.r.t. 
student logits, so gradient is different: - d/d(student_logits) KL(p||q) = student_probs - teacher_probs - """ - Assert.eq( - teacher_softmax_temperature, - 1, - msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel forward KL", - ) - Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel forward KL") - Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) - - # Compute log softmax for both teacher and student - teacher_log_probs = distributed_log_softmax(target.float(), group=group) - student_log_probs = distributed_log_softmax(logits, group=group) - - teacher_probs = teacher_log_probs.exp() - # Forward KL: p * log(p/q) = p * (log_p - log_q) - log_ratio = teacher_log_probs - student_log_probs - del teacher_log_probs - - # Compute loss: sum over vocab of teacher_probs * log_ratio - loss_terms = (teacher_probs * log_ratio).sum(dim=-1) - del log_ratio - - if loss_mask is not None: - valid = loss_mask.to(loss_terms.dtype) - loss_terms = loss_terms * valid - valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype)) - loss = loss_terms.sum() - - if group is not None: - all_reduce(loss, op=ReduceOp.SUM, group=group) - loss /= valid_tokens - - if grad_output is not None: - # Gradient: d/d(student_logits) KL(p||q) = student_probs - teacher_probs - student_probs = student_log_probs.exp() - grad_base = student_probs - teacher_probs - del student_probs, teacher_probs, student_log_probs - - if loss_mask is not None: - grad_base.mul_(loss_mask.to(logits.dtype).unsqueeze(-1)) - - grad_base.mul_(grad_output / valid_tokens) - grad = grad_base.to(logits.dtype) - else: - grad = None - - return loss.detach_(), grad - - def forward_kl_forward_backward( logits: torch.Tensor, target: torch.Tensor, @@ -467,25 +406,20 @@ def forward_kl_forward_backward( loss: Forward KL divergence loss grad: Gradients w.r.t. logits """ - - if sequence_parallel_logits: - # TODO: see hybrid dev branch where it is implemented - raise NotImplementedError("Sequence-parallel forward KL is not implemented yet, set vocab_parallel true") - - Assert.eq(target_format, TargetFormat.logits, msg="Forward KL only supports logits format") + assert target_format == TargetFormat.logits, "Forward KL only supports logits format" Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) - - # TODO: implement fused? 
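# A standalone sketch (plain PyTorch, not the fused/compiled kernel being edited
# here) of the identity this refactor leans on: with teacher probabilities
# p = softmax(t) and student log-probabilities log q = log_softmax(s),
#   KL(p || q) = CE(p, q) - H(p),
# so forward KL falls out of soft-target cross-entropy minus the teacher entropy,
# and the gradient w.r.t. the student logits is softmax(s) - p.
import torch

def forward_kl_reference(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    p = torch.softmax(teacher_logits, dim=-1)
    log_q = torch.log_softmax(student_logits, dim=-1)
    cross_entropy = -(p * log_q).sum(dim=-1)                   # CE(p, q)
    teacher_entropy = -(p * torch.log(p + 1e-20)).sum(dim=-1)  # H(p)
    return (cross_entropy - teacher_entropy).mean()

torch.manual_seed(0)
s = torch.randn(4, 8, requires_grad=True)
t = torch.randn(4, 8)
kl = forward_kl_reference(s, t)
p = torch.softmax(t, dim=-1)
kl_by_definition = (p * (torch.log_softmax(t, dim=-1) - torch.log_softmax(s, dim=-1))).sum(-1).mean()
assert torch.allclose(kl, kl_by_definition, atol=1e-6)
kl.backward()
assert torch.allclose(s.grad, (torch.softmax(s, dim=-1) - p) / s.shape[0], atol=1e-6)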
- distillation_loss, distillation_grad = _forward_kl_forward_backward( + distillation_loss, distillation_grad, teacher_entropy = _fused_cross_entropy_forward_backward( logits=logits, target=target, loss_mask=loss_mask, grad_output=grad_output, logits_scale_factor=logits_scale_factor, - teacher_softmax_temperature=teacher_softmax_temperature, + target_format=target_format, group=group, + teacher_softmax_temperature=teacher_softmax_temperature, + return_target_entropy=True, + **kwargs, ) + distillation_loss -= teacher_entropy + return distillation_loss, distillation_grad diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 411e98f4..e2ce6ae1 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -9,7 +9,7 @@ from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.language_model.lm_head_losses import LossConfig +from fast_llm.layers.language_model.lm_head_losses import LanguageModelLossConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -135,7 +135,7 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Configuration for the final normalization layer.", hint=FieldHint.architecture, ) - losses: dict[str, LossConfig] = Field( + losses: dict[str, LanguageModelLossConfig] = Field( default_factory=dict, desc="A dictionary of loss names and their configurations.", hint=FieldHint.core, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index e1f30332..6ba45c24 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -375,7 +375,7 @@ def _logits_loss_forward_backward( total_loss, grad = None, None for loss_name, loss_config in self._config.losses.items(): - if loss_config.factor == 0.0 and not loss_config.log_it: + if loss_config.factor == 0.0: continue # losses are returned unscaled but the grads are already scaled # we log unscaled losses seperately and the scaled total loss @@ -391,7 +391,7 @@ def _logits_loss_forward_backward( ) loss_ = loss_unscaled_ * loss_config.factor * self._loss_coefficient - if losses is not None and loss_config.log_it: + if losses is not None: losses[self._formatted_loss_names[loss_name]].append(loss_unscaled_.detach()) if total_loss is None: @@ -438,11 +438,10 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) ) for loss_name, loss_config in self._config.losses.items(): - if loss_config.log_it: - loss_def: LossDef = loss_config.get_loss_def( - name=loss_name, count=count, prediction_distance=self._prediction_distance - ) - loss_defs.append(loss_def) + loss_def: LossDef = loss_config.get_loss_def( + name=loss_name, count=count, prediction_distance=self._prediction_distance + ) + loss_defs.append(loss_def) return loss_defs diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index 9fd94662..3695954b 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -44,10 +44,10 @@ def has_any_target(self) -> bool: @config_class(registry=True) -class LossConfig(Config): +class LanguageModelLossConfig(Config): """ Losses canm register themselves - using 
@config_class(dynamic_type={LossConfig: "loss_type_name"}) + using @config_class(dynamic_type= {LanguageModelLossConfig: "loss_type_name"}) """ _name: typing.ClassVar[str] @@ -60,12 +60,6 @@ class LossConfig(Config): valid=check_field(Assert.geq, 0.0), ) - log_it: bool = Field( - default=True, - hint=FieldHint.optional, - desc="Whether to log this loss.", - ) - @abc.abstractmethod def compute_loss( self, @@ -90,10 +84,6 @@ def get_loss_def(self, name: str, count: int = 1, prediction_distance: int | Non def _validate(self): Assert.geq(self.factor, 0.0) - if self.factor > 0.0: - with self._set_implicit_default(): - if "log_it" not in self._explicit_fields: - self.log_it = True super()._validate() def get_formatted_name(self, name=None, prediction_distance: int | None = None) -> str: @@ -103,8 +93,8 @@ def get_formatted_name(self, name=None, prediction_distance: int | None = None) return name -@config_class(dynamic_type={LossConfig: "cross_entropy_lm_loss"}) -class CrossEntropyLMLossConfig(LossConfig): +@config_class(dynamic_type={LanguageModelLossConfig: "cross_entropy"}) +class CrossEntropyLMLossConfig(LanguageModelLossConfig): _name: typing.ClassVar[str] = "CE" _abstract: typing.ClassVar[bool] = False @@ -159,8 +149,8 @@ def compute_loss( ) -@config_class(dynamic_type={LossConfig: "fkl_dist"}) -class ForwardKLLossConfig(LossConfig): +@config_class(dynamic_type={LanguageModelLossConfig: "forward_kl_distillation"}) +class ForwardKLLossConfig(LanguageModelLossConfig): """Forward KL divergence KL(p||q) for distillation (mode-covering).""" _name: typing.ClassVar[str] = "FwdKL" @@ -201,8 +191,8 @@ def compute_loss( ) -@config_class(dynamic_type={LossConfig: "revkl_dist"}) -class ReverseKLLossConfig(LossConfig): +@config_class(dynamic_type={LanguageModelLossConfig: "reverse_kl_distillation"}) +class ReverseKLLossConfig(LanguageModelLossConfig): """Reverse KL divergence KL(q||p) for distillation (mode-seeking).""" _name: typing.ClassVar[str] = "RevKL" @@ -244,8 +234,8 @@ def compute_loss( ) -@config_class(dynamic_type={LossConfig: "dpo"}) -class DPOLossConfig(LossConfig): +@config_class(dynamic_type={LanguageModelLossConfig: "dpo"}) +class DPOLossConfig(LanguageModelLossConfig): """Direct Preference Optimization (DPO) loss for alignment.""" _name: typing.ClassVar[str] = "DPO" diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 6bdaf3f6..ddfc2fc1 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -9,7 +9,7 @@ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.language_model.lm_head_losses import LossConfig +from fast_llm.layers.language_model.lm_head_losses import LanguageModelLossConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -69,7 +69,7 @@ def _lm_head( grad_output: float = 1.0, logit_scale_factor: float = 1.0, logit_z_loss=0.0, - losses: dict[str, LossConfig], + losses: dict[str, LanguageModelLossConfig], ): hidden = torch.rms_norm( input_.to(rms_weight.dtype), @@ -80,7 +80,7 @@ def _lm_head( logits = torch.nn.functional.linear(hidden, logit_weight).float() if "dist_loss" in losses: - if losses["dist_loss"].type == "revkl_dist": + if losses["dist_loss"].type == "reverse_kl_distillation": Assert.eq(logits.shape, 
target.shape) loss = _reverse_kl_loss( (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask @@ -89,7 +89,7 @@ def _lm_head( loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].factor)) # Return scaled loss return loss * losses["dist_loss"].factor, None - elif losses["dist_loss"].type == "fkl_dist": + elif losses["dist_loss"].type == "forward_kl_distillation": Assert.eq(logits.shape, target.shape) loss = _kl_loss( (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask @@ -137,14 +137,12 @@ def _lm_head( # "distillation_model": "distillation", # "losses": { # "lm_loss": { - # "type": "cross_entropy_lm_loss", + # "type": "cross_entropy", # "weight_scalor": 0.0, - # "log_it": False, # }, # "dist_loss": { # "type": "cross_entropy_dist", # TODO: Not implemented yet # "weight_scalor": 1.0, - # "log_it": True, # } # } # } @@ -153,87 +151,18 @@ def _lm_head( # False, # 1, # ), - ( - { - "head": { - "distillation_model": "distillation", - "losses": { - "lm_loss": { - "type": "cross_entropy_lm_loss", - "factor": 0.0, - "log_it": False, - }, - "dist_loss": { - "type": "revkl_dist", - "factor": 1.0, - "log_it": True, - }, - }, - } - }, - {}, - False, - 1, - ), - # Skip - CE distillation not implemented - # ( - # { - # "head": { - # "distillation_model": "distillation", - # "losses": { - # "lm_loss": { - # "type": "cross_entropy_lm_loss", - # "weight_scalor": 1.0, - # "log_it": True, - # }, - # "dist_loss": { - # "type": "cross_entropy_dist", # TODO - # "weight_scalor": 1.0, - # "log_it": True, - # } - # } - # } - # }, - # {}, - # True, - # 1, - # ), - ( - { - "head": { - "distillation_model": "distillation", - "losses": { - "lm_loss": { - "type": "cross_entropy_lm_loss", - "factor": 0.0, - "log_it": False, - }, - "dist_loss": { - "type": "revkl_dist", - "factor": 1.0, - "log_it": True, - }, - }, - } - }, - {}, - True, - 1, - ), pytest.param( { "head": { "distillation_model": "distillation", "losses": { "lm_loss": { - "type": "cross_entropy_lm_loss", + "type": "cross_entropy", "factor": 0.0, - "log_it": True, # tracking even with zero weight }, "dist_loss": { - "type": "revkl_dist", + "type": "reverse_kl_distillation", "factor": 1.0, - "log_it": True, }, }, } @@ -249,37 +178,12 @@ def _lm_head( "distillation_model": "distillation", "losses": { "lm_loss": { - "type": "cross_entropy_lm_loss", + "type": "cross_entropy", "factor": 0.0, - "log_it": True, # tracking with zero weight }, "dist_loss": { - "type": "revkl_dist", + "type": "reverse_kl_distillation", "factor": 0.0, - "log_it": True, # tracking with zero weight - }, - }, - } - }, - {}, - False, - 1, - id="track_both_zero_factors", - ), - pytest.param( - { - "head": { - "distillation_model": "distillation", - "losses": { - "lm_loss": { - "type": "cross_entropy_lm_loss", - "factor": 0.0, - "log_it": False, # not tracking with zero weight - }, - "dist_loss": { - "type": "revkl_dist", - "factor": 0.0, - "log_it": False, # not tracking with zero weight }, }, } @@ -288,24 +192,22 @@ def _lm_head( False, 1, marks=pytest.mark.xfail( - reason="No losses computed when all factors=0 and log_it=False", + reason="Cannot track both losses with zero factor", strict=True, ), - id="zero_factors_no_tracking", + id="track_both_zero_factors", ), pytest.param( { "head": { "losses": { "lm_loss": { - "type": "cross_entropy_lm_loss", + "type": "cross_entropy", "factor": 1.0, - "log_it": False, # not tracking with zero weight }, "dist_loss": { - "type": 
"revkl_dist", + "type": "reverse_kl_distillation", "factor": 1.0, - "log_it": True, # not tracking with zero weight }, }, } @@ -332,10 +234,9 @@ def test_lm_head( "normalization": {"type": "rms_norm"}, "losses": { "lm_loss": { - "type": "cross_entropy_lm_loss", + "type": "cross_entropy", "implementation": cross_entropy_impl, "factor": 1.0, - "log_it": True, } }, } @@ -480,9 +381,8 @@ def test_lm_head( # Get expected loss names from the loss configs for loss_name, loss_config in head._config.losses.items(): - if loss_config.log_it: - formatted_name = loss_config.get_formatted_name(loss_name, prediction_distance) - expected_loss_keys.add(formatted_name) + formatted_name = loss_config.get_formatted_name(loss_name, prediction_distance) + expected_loss_keys.add(formatted_name) if ref_z_loss is not None: expected_loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 3cadb4e2..93c78b58 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -243,7 +243,7 @@ def _update_and_add_testing_config( "head": { "output_weight": init_1, "losses": { - "lm_loss": {"type": "cross_entropy_lm_loss", "factor": 1.0, "log_it": True}, + "lm_loss": {"type": "cross_entropy", "factor": 1.0}, }, }, "hidden_size": 256, @@ -559,7 +559,7 @@ def _update_and_add_testing_config( ("model", "base_model", "head", "distillation_model"): "teacher", ("model", "base_model", "head", "losses"): { "distillation_loss": { - "type": "revkl_dist", + "type": "reverse_kl_distillation", "factor": 1.0, }, }, From 31cfb84dd2081c0d1c40f31dee20859105e50146 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 23 Dec 2025 02:22:15 +0000 Subject: [PATCH 15/21] wip --- fast_llm/data/dataset/gpt/config.py | 1 - fast_llm/layers/language_model/config.py | 14 ++++++++++++-- fast_llm/layers/language_model/head.py | 2 +- tests/test_config.py | 8 +++++++- 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 41a2fe7f..5e978ac2 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -65,7 +65,6 @@ def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleTy def _load_config(self) -> SampledDatasetConfig[SampleType]: assert self.path.is_file(), f"File {self.path} does not exist." config = yaml.safe_load(self.path.open("r")) - Assert.eq(config.keys(), {"config", "metadata"}) if config.keys() == {"config", "metadata"}: # Newer format with metadata config = config["config"] diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index e2ce6ae1..58e85f5d 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -209,12 +209,22 @@ def layer_class(self) -> "type[LanguageModelHead]": return LanguageModelHead def _validate(self) -> None: + with self._set_implicit_default(): + if not self.losses: + if "losses" not in self._explicit_fields: + self.losses = { + "lm_loss": LanguageModelLossConfig._from_dict( + { + "type": "cross_entropy", + "factor": 1.0, + } + ) + } for loss_config in self.losses.values(): - if "dist" in loss_config.type: + if "distillation" in loss_config.type: assert self.distillation_model is not None, "Distillation loss requires a distillation model." 
super()._validate() assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both - # Note: Default loss is handled at runtime in head.py if losses dict is empty @property def max_prediction_distance(self) -> int: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 6ba45c24..a67869f8 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -100,7 +100,7 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) - assert self._config.losses, "At least one loss must be configured." + self._formatted_loss_names = { loss_name: loss_config.get_formatted_name(loss_name, self._prediction_distance) for loss_name, loss_config in self._config.losses.items() diff --git a/tests/test_config.py b/tests/test_config.py index 8d6f3924..81137b58 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -147,14 +147,16 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): "normalization": {"implementation": "triton"}, }, "num_blocks": 12, - "head": {}, }, + "head": {"losses": {"lm_loss": {"type": "cross_entropy", "factor": 1.0}}}, "hidden_size": 512, "tied_embedding_weight": False, "peft": {"freeze_others": False}, } else: expected_config["base_model"] = base_model_update + # added by default + expected_config["base_model"]["head"] = {"losses": {"lm_loss": {"type": "cross_entropy", "factor": 1.0}}} check_equal_nested(_trim_type(serialized_config), _trim_type(expected_config)) @@ -297,3 +299,7 @@ def test_distributed_global_ranks(bdp: int, sdp: int, tp: int, pp: int, pipeline Assert.eq(len({global_rank for global_ranks in global_ranks_set for global_rank in global_ranks}), world_size) Assert.eq(len(rank_breakdowns), world_size) + + +if __name__ == "__main__": + pytest.main([__file__]) From 24fe67bbebbdd9a8aa5ad1393b43250ced3b8629 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 23 Dec 2025 15:43:26 +0000 Subject: [PATCH 16/21] no grad if factor 0 --- fast_llm/layers/language_model/head.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index a67869f8..50240f49 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -383,7 +383,9 @@ def _logits_loss_forward_backward( logits, targets, grad_output=( - grad_output * self._loss_coefficient * loss_config.factor if grad_output is not None else None + (grad_output * self._loss_coefficient * loss_config.factor if grad_output is not None else None) + if loss_config.factor != 0.0 + else None ), group=group, logits_scale_factor=self._config.logits_scale_factor, From 0e562e99198e8414b1c026d17cd3383c7acc2f55 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 23 Dec 2025 17:00:00 +0000 Subject: [PATCH 17/21] addressed comments --- fast_llm/layers/language_model/config.py | 2 +- fast_llm/layers/language_model/head.py | 8 +++--- .../layers/language_model/lm_head_losses.py | 4 +-- tests/layers/test_lm_head.py | 26 +++++++++---------- tests/test_config.py | 4 +-- 5 files changed, 22 insertions(+), 22 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 58e85f5d..4bd8a592 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -216,7 +216,7 @@ def _validate(self) -> None: "lm_loss": LanguageModelLossConfig._from_dict( { "type": "cross_entropy", - "factor": 1.0, + 
"weight": 1.0, } ) } diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 50240f49..40c09961 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -375,7 +375,7 @@ def _logits_loss_forward_backward( total_loss, grad = None, None for loss_name, loss_config in self._config.losses.items(): - if loss_config.factor == 0.0: + if loss_config.weight == 0.0: continue # losses are returned unscaled but the grads are already scaled # we log unscaled losses seperately and the scaled total loss @@ -383,15 +383,15 @@ def _logits_loss_forward_backward( logits, targets, grad_output=( - (grad_output * self._loss_coefficient * loss_config.factor if grad_output is not None else None) - if loss_config.factor != 0.0 + (grad_output * self._loss_coefficient * loss_config.weight if grad_output is not None else None) + if loss_config.weight != 0.0 else None ), group=group, logits_scale_factor=self._config.logits_scale_factor, vocab_parallel=self._vocab_parallel, ) - loss_ = loss_unscaled_ * loss_config.factor * self._loss_coefficient + loss_ = loss_unscaled_ * loss_config.weight * self._loss_coefficient if losses is not None: losses[self._formatted_loss_names[loss_name]].append(loss_unscaled_.detach()) diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index 3695954b..dc367be6 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -53,7 +53,7 @@ class LanguageModelLossConfig(Config): _name: typing.ClassVar[str] _abstract: typing.ClassVar[bool] = True - factor: float = Field( + weight: float = Field( default=1.0, hint=FieldHint.core, desc="Weight for this loss in the total loss computation.", @@ -83,7 +83,7 @@ def get_loss_def(self, name: str, count: int = 1, prediction_distance: int | Non ) def _validate(self): - Assert.geq(self.factor, 0.0) + Assert.geq(self.weight, 0.0) super()._validate() def get_formatted_name(self, name=None, prediction_distance: int | None = None) -> str: diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index ddfc2fc1..7f9e55b7 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -86,18 +86,18 @@ def _lm_head( (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask ) # Apply distillation_loss_factor to grad_output for backward - loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].factor)) + loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].weight)) # Return scaled loss - return loss * losses["dist_loss"].factor, None + return loss * losses["dist_loss"].weight, None elif losses["dist_loss"].type == "forward_kl_distillation": Assert.eq(logits.shape, target.shape) loss = _kl_loss( (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask ) # Apply distillation_loss_factor to grad_output for backward - loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].factor)) + loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].weight)) # Return scaled loss - return loss * losses["dist_loss"].factor, None + return loss * losses["dist_loss"].weight, None if logit_scale_factor != 1.0: logits *= logit_scale_factor @@ -105,8 +105,8 @@ def _lm_head( # Language model loss (cross-entropy with hard labels) loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) # 
Apply language_model_loss_factor - loss.backward(torch.full_like(loss, grad_output * losses["lm_loss"].factor)) - return loss * losses["lm_loss"].factor, z_loss + loss.backward(torch.full_like(loss, grad_output * losses["lm_loss"].weight)) + return loss * losses["lm_loss"].weight, z_loss SEQUENCE_LENGTH = 200 @@ -158,11 +158,11 @@ def _lm_head( "losses": { "lm_loss": { "type": "cross_entropy", - "factor": 0.0, + "weight": 0.0, }, "dist_loss": { "type": "reverse_kl_distillation", - "factor": 1.0, + "weight": 1.0, }, }, } @@ -179,11 +179,11 @@ def _lm_head( "losses": { "lm_loss": { "type": "cross_entropy", - "factor": 0.0, + "weight": 0.0, }, "dist_loss": { "type": "reverse_kl_distillation", - "factor": 0.0, + "weight": 0.0, }, }, } @@ -203,11 +203,11 @@ def _lm_head( "losses": { "lm_loss": { "type": "cross_entropy", - "factor": 1.0, + "weight": 1.0, }, "dist_loss": { "type": "reverse_kl_distillation", - "factor": 1.0, + "weight": 1.0, }, }, } @@ -236,7 +236,7 @@ def test_lm_head( "lm_loss": { "type": "cross_entropy", "implementation": cross_entropy_impl, - "factor": 1.0, + "weight": 1.0, } }, } diff --git a/tests/test_config.py b/tests/test_config.py index 81137b58..3c6a76a3 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -148,7 +148,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "num_blocks": 12, }, - "head": {"losses": {"lm_loss": {"type": "cross_entropy", "factor": 1.0}}}, + "head": {"losses": {"lm_loss": {"type": "cross_entropy", "weight": 1.0}}}, "hidden_size": 512, "tied_embedding_weight": False, "peft": {"freeze_others": False}, @@ -156,7 +156,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): else: expected_config["base_model"] = base_model_update # added by default - expected_config["base_model"]["head"] = {"losses": {"lm_loss": {"type": "cross_entropy", "factor": 1.0}}} + expected_config["base_model"]["head"] = {"losses": {"lm_loss": {"type": "cross_entropy", "weight": 1.0}}} check_equal_nested(_trim_type(serialized_config), _trim_type(expected_config)) From 52c1c113d1fe32732b7bc2c666c0cfd6303abca8 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 23 Dec 2025 17:44:53 +0000 Subject: [PATCH 18/21] addressed comments --- fast_llm/functional/cross_entropy.py | 4 --- fast_llm/layers/language_model/head.py | 11 ++----- .../layers/language_model/lm_head_losses.py | 29 ++++++++++--------- tests/utils/model_configs.py | 2 +- 4 files changed, 19 insertions(+), 27 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 06c85848..03f7a88e 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -247,7 +247,6 @@ def _reverse_kl_forward_backward( group: ProcessGroup | None = None, logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, - **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Reverse KL using PyTorch's native kl_div function. @@ -325,7 +324,6 @@ def reverse_kl_forward_backward( teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, sequence_parallel_logits: bool = False, - **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). 
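# Plain-PyTorch sketch of the two distillation directions referenced in these
# docstrings (assumed shape [tokens, vocab]; no loss mask, logit scaling or
# tensor parallelism): reverse KL(q||p) is mode-seeking, forward KL(p||q) is
# mode-covering.
import torch
import torch.nn.functional as F

def reverse_kl(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    log_q = torch.log_softmax(student_logits, dim=-1)
    log_p = torch.log_softmax(teacher_logits, dim=-1)
    # kl_div(input=log_p, target=log_q, log_target=True) = sum q * (log q - log p)
    return F.kl_div(log_p, log_q, log_target=True, reduction="batchmean")

def forward_kl(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    log_q = torch.log_softmax(student_logits, dim=-1)
    log_p = torch.log_softmax(teacher_logits, dim=-1)
    return F.kl_div(log_q, log_p, log_target=True, reduction="batchmean")

torch.manual_seed(0)
student, teacher = torch.randn(16, 32), torch.randn(16, 32)
assert reverse_kl(teacher, teacher) == 0  # both vanish when the distributions match
print(reverse_kl(student, teacher), forward_kl(student, teacher))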
@@ -383,7 +381,6 @@ def forward_kl_forward_backward( teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, sequence_parallel_logits: bool = False, - **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute forward KL divergence: KL(p||q) where p is the target distribution (teacher) and q is the predicted (student). @@ -418,7 +415,6 @@ def forward_kl_forward_backward( group=group, teacher_softmax_temperature=teacher_softmax_temperature, return_target_entropy=True, - **kwargs, ) distillation_loss -= teacher_entropy diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 40c09961..bce20c83 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -182,14 +182,10 @@ def _get_targets(self, kwargs: dict) -> Targets | None: dpo_target, reference_model_logits, loss_mask, - chosen_spans, - rejected_spans, dpo_reference_model_logits, - ) = (None, None, None, None, None, None, None) + ) = (None, None, None, None, None) if self._config.enable_dpo: dpo_target = kwargs.get(LanguageModelKwargs.labels) - chosen_spans = kwargs.get(LanguageModelKwargs.chosen_spans) - rejected_spans = kwargs.get(LanguageModelKwargs.rejected_spans) dpo_reference_model_logits = (kwargs.get(f"{self._config.dpo_reference_model}_logits"),) else: if self._config.distillation_model is not None: @@ -230,8 +226,6 @@ def _get_targets(self, kwargs: dict) -> Targets | None: dpo_target=dpo_target, lm_target=lm_target, loss_mask=loss_mask, - chosen_spans=chosen_spans, - rejected_spans=rejected_spans, reference_model_logits=reference_model_logits, dpo_reference_model_logits=dpo_reference_model_logits, ) @@ -302,8 +296,6 @@ def _logits_cross_entropy_forward_backward_split( dpo_target=dpo_target_, reference_model_logits=reference_model_logits_, loss_mask=loss_mask_, - chosen_spans=targets.chosen_spans, - rejected_spans=targets.rejected_spans, dpo_reference_model_logits=targets.dpo_reference_model_logits, ) loss_, grad_ = self._logits_loss_forward_backward( @@ -390,6 +382,7 @@ def _logits_loss_forward_backward( group=group, logits_scale_factor=self._config.logits_scale_factor, vocab_parallel=self._vocab_parallel, + kwargs=kwargs, ) loss_ = loss_unscaled_ * loss_config.weight * self._loss_coefficient diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index dc367be6..4be129a2 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -34,8 +34,6 @@ class Targets: lm_target: "torch.Tensor | None" = None dpo_target: "torch.Tensor | None" = None loss_mask: "torch.Tensor | None" = None - chosen_spans: list[list[tuple[int, int]]] | None = None - rejected_spans: list[list[tuple[int, int]]] | None = None reference_model_logits: "torch.Tensor | None" = None dpo_reference_model_logits: "torch.Tensor | None" = None @@ -64,12 +62,12 @@ class LanguageModelLossConfig(Config): def compute_loss( self, logits: "torch.Tensor", - target: Targets, + targets: Targets, grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, - **kwargs, + kwargs: dict | None = None, ) -> "tuple[torch.Tensor, torch.Tensor | None]": pass @@ -119,7 +117,7 @@ def compute_loss( group: "ProcessGroup" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, - **kwargs, + kwargs: dict | None = None, ) -> "tuple[torch.Tensor, 
torch.Tensor | None]": from fast_llm.functional.cross_entropy import cross_entropy_forward_backward @@ -145,7 +143,6 @@ def compute_loss( logits_scale_factor=logits_scale_factor, teacher_softmax_temperature=self.teacher_softmax_temperature, target_format=TargetFormat.labels, - **kwargs, ) @@ -170,7 +167,8 @@ def compute_loss( grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, - **kwargs, + vocab_parallel: bool = False, + kwargs: dict | None = None, ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.cross_entropy import forward_kl_forward_backward @@ -187,7 +185,6 @@ def compute_loss( logits_scale_factor=logits_scale_factor, teacher_softmax_temperature=self.teacher_softmax_temperature, target_format=TargetFormat.logits, - **kwargs, ) @@ -212,7 +209,8 @@ def compute_loss( grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, - **kwargs, + vocab_parallel: bool = False, + kwargs: dict | None = None, ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.cross_entropy import reverse_kl_forward_backward @@ -230,7 +228,6 @@ def compute_loss( logits_scale_factor=logits_scale_factor, teacher_softmax_temperature=self.teacher_softmax_temperature, target_format=TargetFormat.logits, - **kwargs, ) @@ -254,16 +251,22 @@ def compute_loss( targets: Targets, grad_output: float | None = None, group: "ProcessGroup" = None, - **kwargs, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.dpo import compute_dpo_loss + from fast_llm.layers.language_model.config import LanguageModelKwargs + + chosen_spans = kwargs.get(LanguageModelKwargs.chosen_spans) + rejected_spans = kwargs.get(LanguageModelKwargs.rejected_spans) return compute_dpo_loss( logits=logits, targets=targets.dpo_target, reference_model_logits=targets.dpo_reference_model_logits, - chosen_spans=targets.chosen_spans, - rejected_spans=targets.rejected_spans, + chosen_spans=chosen_spans, + rejected_spans=rejected_spans, beta=self.beta, grad_output=grad_output, ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 6cda07ad..f3d4659c 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -247,7 +247,7 @@ def _update_and_add_testing_config( "head": { "output_weight": init_1, "losses": { - "lm_loss": {"type": "cross_entropy", "factor": 1.0}, + "lm_loss": {"type": "cross_entropy", "weight": 1.0}, }, }, "hidden_size": 256, From 406d0a2eaf355488a699220ad4198371585effa2 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 30 Dec 2025 20:13:50 +0000 Subject: [PATCH 19/21] Removed Targets class Removed the targets, class, moved tragets processing to losses, made loss masks more explicit --- fast_llm/layers/language_model/config.py | 17 +- fast_llm/layers/language_model/embedding.py | 3 +- fast_llm/layers/language_model/head.py | 139 ++++++----------- fast_llm/layers/language_model/kwargs.py | 23 +++ .../layers/language_model/lm_head_losses.py | 147 +++++++++++++----- fast_llm/models/gpt/model.py | 2 +- fast_llm/models/multimodal/model.py | 2 +- tests/layers/test_lm_head.py | 3 +- 8 files changed, 185 insertions(+), 151 deletions(-) create mode 100644 fast_llm/layers/language_model/kwargs.py diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 4bd8a592..9f6cbf4c 100644 --- 
a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -5,7 +5,7 @@ from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig +from fast_llm.layers.block.config import BlockConfig, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig @@ -19,21 +19,6 @@ from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction -class LanguageModelKwargs(BlockKwargs): - token_ids = "token_ids" - position_ids = "position_ids" - token_map = "token_map" - sample_map = "sample_map" - embedding_map = "embedding_map" - # TODO: These are generic - labels = "labels" - phase = "phase" - chosen_spans = "chosen_spans" - rejected_spans = "rejected_spans" - loss_mask = "loss_mask" - mask_inputs = "mask_inputs" - - @config_class() class LanguageModelEmbeddingsConfig(BlockConfig): _abstract = False diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 93850d24..fda5e338 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -10,7 +10,8 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.layers.block.block import Block from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index bce20c83..27b090c1 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -22,9 +22,9 @@ LanguageModelEmbeddingsConfig, LanguageModelHeadBaseConfig, LanguageModelHeadConfig, - LanguageModelKwargs, ) -from fast_llm.layers.language_model.lm_head_losses import Targets, _format_name +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs +from fast_llm.layers.language_model.lm_head_losses import _format_name from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, div, get_unique @@ -101,10 +101,12 @@ def __init__( peft=self._peft, ) - self._formatted_loss_names = { - loss_name: loss_config.get_formatted_name(loss_name, self._prediction_distance) - for loss_name, loss_config in self._config.losses.items() - } + self._formatted_loss_names = {} + for loss_name, loss_config in self._config.losses.items(): + if loss_config.weight > 0.0: + self._formatted_loss_names[loss_name] = loss_config.get_formatted_name( + loss_name, self._prediction_distance + ) def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None @@ -154,6 +156,12 @@ def _forward_backward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None ) -> tuple[torch.Tensor, torch.Tensor | None]: targets = self._get_targets(kwargs) + loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) + if 
loss_mask is not None: + loss_mask = loss_mask.flatten() + if self._sequence_parallel_logits: + loss_mask = split_op(loss_mask, self._parallel_dim.group, 0) + input_ = input_.detach().requires_grad_(do_grad := targets is not None and self.training) with torch.enable_grad(): ln_output = self.final_norm(input_) @@ -167,7 +175,7 @@ def _forward_backward( output_weights = self.output_weights loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split( - ln_output.detach(), targets, output_weights, grad_output, kwargs, losses + ln_output.detach(), targets, loss_mask, output_weights, grad_output, kwargs, losses ) if do_grad: @@ -176,62 +184,20 @@ def _forward_backward( else: return loss, None - def _get_targets(self, kwargs: dict) -> Targets | None: - ( - lm_target, - dpo_target, - reference_model_logits, - loss_mask, - dpo_reference_model_logits, - ) = (None, None, None, None, None) - if self._config.enable_dpo: - dpo_target = kwargs.get(LanguageModelKwargs.labels) - dpo_reference_model_logits = (kwargs.get(f"{self._config.dpo_reference_model}_logits"),) - else: - if self._config.distillation_model is not None: - # Target is reference model logits. - reference_model_logits = kwargs[f"{self._config.distillation_model}_logits"].flatten(0, -2) - loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) - if loss_mask is not None: - loss_mask = loss_mask.flatten() - - lm_target = kwargs.get(LanguageModelKwargs.labels) - if lm_target is not None: - # MTP: Shift the labels - lm_target_sequence_length = ( - lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - self._prediction_heads - ) - if LanguageModelKwargs.sequence_q_dim in kwargs: - Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) - lm_target_slice = slice( - self._prediction_distance, self._prediction_distance + lm_target_sequence_length - ) - lm_target = ( - lm_target[lm_target_slice] - if kwargs[LanguageModelKwargs.sequence_first] - else lm_target[:, lm_target_slice] - ).flatten() - - if self._sequence_parallel_logits: - if dpo_target is not None: - dpo_target = split_op(dpo_target, self._parallel_dim.group, 0) - if lm_target is not None: - lm_target = split_op(lm_target, self._parallel_dim.group, 0) - if loss_mask is not None: - loss_mask = split_op(loss_mask, self._parallel_dim.group, 0) - if reference_model_logits is not None: - reference_model_logits = split_op(reference_model_logits, self._parallel_dim.group, 0) - - targets = Targets( - dpo_target=dpo_target, - lm_target=lm_target, - loss_mask=loss_mask, - reference_model_logits=reference_model_logits, - dpo_reference_model_logits=dpo_reference_model_logits, - ) - - # Return None if no targets are set - if not targets.has_any_target(): + def _get_targets(self, kwargs: dict) -> dict | None: + targets = {} + for loss_config in self._config.losses.values(): + if loss_config.weight == 0.0: + continue + loss_targets = loss_config.extract_targets_from_global_kwargs( + kwargs, + prediction_distance=self._prediction_distance, + prediction_heads=self._prediction_heads, + head_config=self._config, + sequence_parallel_logits=self._sequence_parallel_logits, + ) + targets.update({k: v for k, v in loss_targets.items() if v is not None}) + if len(targets) == 0: return None return targets @@ -241,15 +207,16 @@ def get_output_weights(self) -> list[torch.Tensor]: def _logits_cross_entropy_forward_backward_split( self, input_: torch.Tensor, - targets: Targets | None, + targets: dict[str, "torch.Tensor"] | None, + loss_mask: torch.Tensor | None, 
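# Hypothetical standalone reduction of the per-loss accumulation pattern the head
# uses in this patch (the real code additionally handles grad scaling,
# cross-entropy splits and tensor parallelism): zero-weight losses are skipped
# outright, unscaled values are kept for logging, and the total is the weighted sum.
import torch
from typing import Callable

def combine_losses(
    logits: torch.Tensor,
    loss_mask: torch.Tensor | None,
    loss_fns: dict[str, tuple[float, Callable]],  # name -> (weight, fn(logits, loss_mask) -> unscaled loss)
    losses_log: dict[str, list[torch.Tensor]] | None = None,
) -> torch.Tensor | None:
    total = None
    for name, (weight, fn) in loss_fns.items():
        if weight == 0.0:
            continue  # no forward pass and no gradient contribution
        unscaled = fn(logits, loss_mask)
        if losses_log is not None:
            losses_log.setdefault(name, []).append(unscaled.detach())
        scaled = unscaled * weight
        total = scaled if total is None else total + scaled
    return total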
weight: torch.Tensor, grad_output: float, kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - if self._config.cross_entropy_splits is None or targets is None: + if self._config.cross_entropy_splits is None: loss, logit_input_grad = self._logits_loss_forward_backward( - input_, targets, weight, grad_output, kwargs, losses + input_, targets, loss_mask, weight, grad_output, kwargs, losses ) if targets is None: # TODO: Make a proper way of returning the model output. @@ -273,34 +240,28 @@ def _logits_cross_entropy_forward_backward_split( else: logit_input_grad = None - # Extract target tensors for splitting (keep same order as original tuple) - target_tensors = [ - targets.lm_target, - targets.dpo_target, - targets.reference_model_logits, - targets.loss_mask, - ] split_size = div( - get_unique(target.size(0) for target in target_tensors if target is not None), + get_unique(target.size(0) for target in targets.values() if target is not None), self._config.cross_entropy_splits, ) tensors_split = [ [None] * self._config.cross_entropy_splits if tensor is None else tensor.split(split_size) - for tensor in [logit_input, *target_tensors, logit_input_grad] + for tensor in [logit_input, loss_mask, logit_input_grad] ] - for logit_input_, lm_target_, dpo_target_, reference_model_logits_, loss_mask_, logit_input_grad_ in zip( - *tensors_split, strict=True - ): - targets_ = Targets( - lm_target=lm_target_, - dpo_target=dpo_target_, - reference_model_logits=reference_model_logits_, - loss_mask=loss_mask_, - dpo_reference_model_logits=targets.dpo_reference_model_logits, + target_split = { + name: ( + [None] * self._config.cross_entropy_splits + if targets[name] is None + else targets[name].split(split_size) ) + for name in targets + } + + for i, (logit_input_, loss_mask_, logit_input_grad_) in enumerate(zip(*tensors_split, strict=True)): loss_, grad_ = self._logits_loss_forward_backward( logit_input_, - targets_, + {name: target_split[name][i] for name in target_split}, + loss_mask_, weight, grad_output, kwargs, @@ -323,7 +284,8 @@ def _logits_cross_entropy_forward_backward_split( def _logits_loss_forward_backward( self, input_: torch.Tensor, - targets: Targets | None, + targets: dict[str, "torch.Tensor"] | None, + loss_mask: torch.Tensor | None, weight: torch.Tensor, grad_output: float, kwargs: dict, @@ -370,10 +332,9 @@ def _logits_loss_forward_backward( if loss_config.weight == 0.0: continue # losses are returned unscaled but the grads are already scaled - # we log unscaled losses seperately and the scaled total loss loss_unscaled_, grad_ = loss_config.compute_loss( logits, - targets, + loss_mask, grad_output=( (grad_output * self._loss_coefficient * loss_config.weight if grad_output is not None else None) if loss_config.weight != 0.0 @@ -382,7 +343,7 @@ def _logits_loss_forward_backward( group=group, logits_scale_factor=self._config.logits_scale_factor, vocab_parallel=self._vocab_parallel, - kwargs=kwargs, + kwargs={**kwargs, **targets}, ) loss_ = loss_unscaled_ * loss_config.weight * self._loss_coefficient diff --git a/fast_llm/layers/language_model/kwargs.py b/fast_llm/layers/language_model/kwargs.py new file mode 100644 index 00000000..4f620388 --- /dev/null +++ b/fast_llm/layers/language_model/kwargs.py @@ -0,0 +1,23 @@ +from fast_llm.layers.block.config import BlockKwargs + + +class TargetsKwargs: + lm_target = "preprocessed_lm_target" + dpo_target = "preprocessed_dpo_target" + reference_model_logits = "reference_model_logits" + 
dpo_reference_model_logits = "dpo_reference_model_logits" + + +class LanguageModelKwargs(BlockKwargs): + token_ids = "token_ids" + position_ids = "position_ids" + token_map = "token_map" + sample_map = "sample_map" + embedding_map = "embedding_map" + # TODO: These are generic + labels = "labels" + phase = "phase" + chosen_spans = "chosen_spans" + rejected_spans = "rejected_spans" + loss_mask = "loss_mask" + mask_inputs = "mask_inputs" diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index 4be129a2..088e5504 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -1,18 +1,20 @@ import abc -import dataclasses import logging import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.core.ops import split_op from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs, TargetsKwargs from fast_llm.utils import Assert if typing.TYPE_CHECKING: import torch from fast_llm.core.distributed import ProcessGroup + from fast_llm.layers.language_model.config import LanguageModelHeadConfig logger = logging.getLogger(__name__) @@ -29,23 +31,10 @@ def _format_name(name: str) -> str: return name.replace("_", " ") -@dataclasses.dataclass -class Targets: - lm_target: "torch.Tensor | None" = None - dpo_target: "torch.Tensor | None" = None - loss_mask: "torch.Tensor | None" = None - reference_model_logits: "torch.Tensor | None" = None - dpo_reference_model_logits: "torch.Tensor | None" = None - - def has_any_target(self) -> bool: - return any(getattr(self, field.name) is not None for field in dataclasses.fields(self)) - - @config_class(registry=True) class LanguageModelLossConfig(Config): """ - Losses canm register themselves - using @config_class(dynamic_type= {LanguageModelLossConfig: "loss_type_name"}) + Losses can register themselves using @config_class(dynamic_type= {LanguageModelLossConfig: "loss_type_name"}). 
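# Small numeric illustration (toy sizes) of the multi-token-prediction label
# slicing that the cross-entropy loss performs in its target extraction further
# down in this patch: the stored labels cover sequence_length + prediction_heads - 1
# positions, and the head at prediction_distance k trains on the window shifted by k.
import torch

sequence_length, prediction_heads = 6, 3
labels = torch.arange(sequence_length + prediction_heads - 1)    # tensor([0, ..., 7])
target_sequence_length = labels.size(0) + 1 - prediction_heads   # == sequence_length
for prediction_distance in range(prediction_heads):
    window = slice(prediction_distance, prediction_distance + target_sequence_length)
    print(prediction_distance, labels[window].tolist())
# 0 [0, 1, 2, 3, 4, 5]
# 1 [1, 2, 3, 4, 5, 6]
# 2 [2, 3, 4, 5, 6, 7]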
""" _name: typing.ClassVar[str] @@ -62,7 +51,7 @@ class LanguageModelLossConfig(Config): def compute_loss( self, logits: "torch.Tensor", - targets: Targets, + loss_mask: "torch.Tensor | None", grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, @@ -90,6 +79,18 @@ def get_formatted_name(self, name=None, prediction_distance: int | None = None) name = f"{name}_{prediction_distance}" return name + @abc.abstractmethod + def extract_targets_from_global_kwargs( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + head_config: "LanguageModelHeadConfig | None" = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + pass + @config_class(dynamic_type={LanguageModelLossConfig: "cross_entropy"}) class CrossEntropyLMLossConfig(LanguageModelLossConfig): @@ -109,10 +110,40 @@ class CrossEntropyLMLossConfig(LanguageModelLossConfig): valid=check_field(Assert.gt, 0.0), ) + def extract_targets_from_global_kwargs( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + head_config: "LanguageModelHeadConfig | None" = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + if kwargs is None: + kwargs = {} + + lm_target = kwargs.get(LanguageModelKwargs.labels) + if lm_target is not None: + # MTP: Shift the labels + lm_target_sequence_length = ( + lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - prediction_heads + ) + if LanguageModelKwargs.sequence_q_dim in kwargs: + Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) + lm_target_slice = slice(prediction_distance, prediction_distance + lm_target_sequence_length) + lm_target = ( + lm_target[lm_target_slice] + if kwargs[LanguageModelKwargs.sequence_first] + else lm_target[:, lm_target_slice] + ).flatten() + if sequence_parallel_logits: + lm_target = split_op(lm_target, group, 0) + return {TargetsKwargs.lm_target: lm_target} + def compute_loss( self, logits: "torch.Tensor", - targets: Targets, + loss_mask: "torch.Tensor | None", grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, @@ -121,9 +152,7 @@ def compute_loss( ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.cross_entropy import cross_entropy_forward_backward - target = targets.lm_target - if target is None: - raise ValueError("CrossEntropyLoss requires lm_target to be set in Targets") + target = kwargs.get(TargetsKwargs.lm_target) implementation = self.implementation if implementation == CrossEntropyImpl.auto: if vocab_parallel: @@ -160,10 +189,29 @@ class ForwardKLLossConfig(LanguageModelLossConfig): valid=check_field(Assert.gt, 0.0), ) + def extract_targets_from_global_kwargs( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + head_config: "LanguageModelHeadConfig | None" = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + if kwargs is None: + kwargs = {} + + reference_model_logits = kwargs.get(f"{head_config.distillation_model}_logits") + if reference_model_logits is not None: + reference_model_logits = reference_model_logits.flatten(0, -2) + if sequence_parallel_logits: + reference_model_logits = 
split_op(reference_model_logits, group, 0) + return {TargetsKwargs.reference_model_logits: reference_model_logits} + def compute_loss( self, logits: "torch.Tensor", - targets: Targets, + loss_mask: "torch.Tensor | None", grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, @@ -172,14 +220,12 @@ def compute_loss( ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.cross_entropy import forward_kl_forward_backward - target = targets.reference_model_logits - if target is None: - raise ValueError("ForwardKLLoss requires distillation_target to be set in Targets") + target = kwargs.get(TargetsKwargs.reference_model_logits) return forward_kl_forward_backward( logits=logits.flatten(0, -2), target=target, - loss_mask=targets.loss_mask, + loss_mask=loss_mask, grad_output=grad_output, group=group, logits_scale_factor=logits_scale_factor, @@ -189,23 +235,16 @@ def compute_loss( @config_class(dynamic_type={LanguageModelLossConfig: "reverse_kl_distillation"}) -class ReverseKLLossConfig(LanguageModelLossConfig): +class ReverseKLLossConfig(ForwardKLLossConfig): """Reverse KL divergence KL(q||p) for distillation (mode-seeking).""" _name: typing.ClassVar[str] = "RevKL" _abstract: typing.ClassVar[bool] = False - teacher_softmax_temperature: float = Field( - default=1.0, - hint=FieldHint.optional, - desc="Temperature for teacher softmax.", - valid=check_field(Assert.gt, 0.0), - ) - def compute_loss( self, logits: "torch.Tensor", - targets: Targets, + loss_mask: "torch.Tensor | None", grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, @@ -215,14 +254,12 @@ def compute_loss( from fast_llm.functional.cross_entropy import reverse_kl_forward_backward # Use distillation_target for KL losses - target = targets.reference_model_logits - if target is None: - raise ValueError("ReverseKLLoss requires distillation_target to be set in Targets") + target = kwargs.get(TargetsKwargs.reference_model_logits) return reverse_kl_forward_backward( logits=logits.flatten(0, -2), target=target, - loss_mask=targets.loss_mask, + loss_mask=loss_mask, grad_output=grad_output, group=group, logits_scale_factor=logits_scale_factor, @@ -245,10 +282,35 @@ class DPOLossConfig(LanguageModelLossConfig): valid=check_field(Assert.gt, 0.0), ) + def extract_targets_from_global_kwargs( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + head_config: "LanguageModelHeadConfig | None" = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + if kwargs is None: + kwargs = {} + + reference_model_logits = kwargs.get(f"{head_config.dpo_reference_model}_logits") + dpo_target = kwargs.get(LanguageModelKwargs.labels) + if reference_model_logits is not None: + reference_model_logits = reference_model_logits.flatten(0, -2) + if sequence_parallel_logits: + reference_model_logits = split_op(reference_model_logits, group, 0) + if dpo_target is not None: + dpo_target = split_op(dpo_target, group, 0) + return { + TargetsKwargs.dpo_reference_model_logits: reference_model_logits, + TargetsKwargs.dpo_target: dpo_target, + } + def compute_loss( self, logits: "torch.Tensor", - targets: Targets, + loss_mask: "torch.Tensor | None", grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, @@ -256,15 +318,16 @@ def compute_loss( kwargs: dict | None = None, ) -> 
"tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.dpo import compute_dpo_loss - from fast_llm.layers.language_model.config import LanguageModelKwargs + dpo_target = kwargs.get(TargetsKwargs.dpo_target) + dpo_reference_model_logits = kwargs.get(TargetsKwargs.dpo_reference_model_logits) chosen_spans = kwargs.get(LanguageModelKwargs.chosen_spans) rejected_spans = kwargs.get(LanguageModelKwargs.rejected_spans) return compute_dpo_loss( logits=logits, - targets=targets.dpo_target, - reference_model_logits=targets.dpo_reference_model_logits, + targets=dpo_target, + reference_model_logits=dpo_reference_model_logits, chosen_spans=chosen_spans, rejected_spans=rejected_spans, beta=self.beta, diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 2f43d1e4..846c6564 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -12,7 +12,7 @@ from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockDimNames, BlockKwargs -from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron diff --git a/fast_llm/models/multimodal/model.py b/fast_llm/models/multimodal/model.py index 890d5760..88da79e6 100644 --- a/fast_llm/models/multimodal/model.py +++ b/fast_llm/models/multimodal/model.py @@ -10,7 +10,7 @@ from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockDimNames, BlockKwargs -from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs from fast_llm.layers.vision.config import VisionKwargs from fast_llm.layers.vision.vision_encoder import VisionMultiModalModel from fast_llm.models.gpt.config import GPTBatchConfig diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 7f9e55b7..ed639db9 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -7,8 +7,9 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelHeadConfig from fast_llm.layers.language_model.head import LanguageModelHead +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs from fast_llm.layers.language_model.lm_head_losses import LanguageModelLossConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert From f25380a191fd53bdc0427bc3592c3a026ad3fd22 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 30 Dec 2025 20:39:22 +0000 Subject: [PATCH 20/21] fixes --- fast_llm/layers/language_model/head.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 27b090c1..cb2312d7 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -195,6 +195,7 @@ 
def _get_targets(self, kwargs: dict) -> dict | None: prediction_heads=self._prediction_heads, head_config=self._config, sequence_parallel_logits=self._sequence_parallel_logits, + group=self._parallel_dim.group, ) targets.update({k: v for k, v in loss_targets.items() if v is not None}) if len(targets) == 0: @@ -240,8 +241,14 @@ def _logits_cross_entropy_forward_backward_split( else: logit_input_grad = None + # Collect all tensors that need to be split to determine the split size + tensors_to_check = [logit_input] + if loss_mask is not None: + tensors_to_check.append(loss_mask) + tensors_to_check.extend(target for target in targets.values() if target is not None) + split_size = div( - get_unique(target.size(0) for target in targets.values() if target is not None), + get_unique(tensor.size(0) for tensor in tensors_to_check), self._config.cross_entropy_splits, ) tensors_split = [ From 8adb7ddb9da22eba3f9a4e8a3cbff0e86ca2f214 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 30 Dec 2025 20:51:52 +0000 Subject: [PATCH 21/21] imports --- .../layers/language_model/lm_head_losses.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index 088e5504..f6e69b4f 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -3,7 +3,6 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.core.ops import split_op from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig @@ -137,6 +136,8 @@ def extract_targets_from_global_kwargs( else lm_target[:, lm_target_slice] ).flatten() if sequence_parallel_logits: + from fast_llm.core.ops import split_op + lm_target = split_op(lm_target, group, 0) return {TargetsKwargs.lm_target: lm_target} @@ -205,6 +206,8 @@ def extract_targets_from_global_kwargs( if reference_model_logits is not None: reference_model_logits = reference_model_logits.flatten(0, -2) if sequence_parallel_logits: + from fast_llm.core.ops import split_op + reference_model_logits = split_op(reference_model_logits, group, 0) return {TargetsKwargs.reference_model_logits: reference_model_logits} @@ -296,12 +299,15 @@ def extract_targets_from_global_kwargs( reference_model_logits = kwargs.get(f"{head_config.dpo_reference_model}_logits") dpo_target = kwargs.get(LanguageModelKwargs.labels) - if reference_model_logits is not None: - reference_model_logits = reference_model_logits.flatten(0, -2) - if sequence_parallel_logits: - reference_model_logits = split_op(reference_model_logits, group, 0) - if dpo_target is not None: - dpo_target = split_op(dpo_target, group, 0) + if reference_model_logits is not None or dpo_target is not None: + from fast_llm.core.ops import split_op + + if reference_model_logits is not None: + reference_model_logits = reference_model_logits.flatten(0, -2) + if sequence_parallel_logits: + reference_model_logits = split_op(reference_model_logits, group, 0) + if dpo_target is not None: + dpo_target = split_op(dpo_target, group, 0) return { TargetsKwargs.dpo_reference_model_logits: reference_model_logits, TargetsKwargs.dpo_target: dpo_target,
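
Note: the sketch below is not part of any patch in this series. It is a minimal illustration of how a third loss could plug into the registry these patches introduce: the `"l2_logit_match"` key, the `L2LogitMatchLossConfig` class, and the squared-error math are hypothetical, while the `@config_class(dynamic_type=...)` registration, the `extract_targets_from_global_kwargs` / `compute_loss` signatures, the `TargetsKwargs` keys, and the deferred `split_op` import all mirror the diffs above. Vocab-parallel reduction and `vocab_parallel` handling are omitted for brevity, and the loss mask is assumed to be 0/1 over flattened token positions.

import typing

import torch

from fast_llm.config import config_class
from fast_llm.layers.language_model.kwargs import TargetsKwargs
from fast_llm.layers.language_model.lm_head_losses import LanguageModelLossConfig

if typing.TYPE_CHECKING:
    from fast_llm.core.distributed import ProcessGroup
    from fast_llm.layers.language_model.config import LanguageModelHeadConfig


@config_class(dynamic_type={LanguageModelLossConfig: "l2_logit_match"})
class L2LogitMatchLossConfig(LanguageModelLossConfig):
    """Hypothetical squared-error match against the distillation teacher's logits."""

    _name: typing.ClassVar[str] = "L2Match"
    _abstract: typing.ClassVar[bool] = False

    def extract_targets_from_global_kwargs(
        self,
        kwargs: dict | None = None,
        prediction_distance: int | None = None,
        prediction_heads: int | None = None,
        head_config: "LanguageModelHeadConfig | None" = None,
        sequence_parallel_logits: bool | None = None,
        group: "ProcessGroup" = None,
    ) -> dict[str, "torch.Tensor"]:
        if kwargs is None:
            kwargs = {}
        # Reuse the teacher logits the distillation losses above already consume.
        target = kwargs.get(f"{head_config.distillation_model}_logits")
        if target is not None:
            target = target.flatten(0, -2)
            if sequence_parallel_logits:
                # Deferred import, following the pattern of the final patch.
                from fast_llm.core.ops import split_op

                target = split_op(target, group, 0)
        return {TargetsKwargs.reference_model_logits: target}

    def compute_loss(
        self,
        logits: "torch.Tensor",
        loss_mask: "torch.Tensor | None",
        grad_output: float | None = None,
        group: "ProcessGroup" = None,
        logits_scale_factor: float | None = None,
        vocab_parallel: bool | None = None,
        kwargs: dict | None = None,
    ) -> "tuple[torch.Tensor, torch.Tensor | None]":
        target = kwargs.get(TargetsKwargs.reference_model_logits)
        if target is None:
            raise ValueError("l2_logit_match requires teacher logits in kwargs")
        student = logits.flatten(0, -2).float()
        if logits_scale_factor is not None:
            student = student * logits_scale_factor
        diff = student - target.float()
        if loss_mask is not None:
            # Assumes a 0/1 mask over flattened token positions, broadcast over the vocab dim.
            diff = diff * loss_mask.unsqueeze(-1)
        loss = diff.pow(2).mean()
        # Manual gradient w.r.t. the logits, matching the (loss, grad) contract of the
        # other losses; grad is reshaped to the original logits shape as an assumption.
        grad = None
        if grad_output is not None:
            grad = ((2.0 * grad_output / diff.numel()) * diff).view_as(logits).to(logits.dtype)
        return loss, grad

Selecting such a loss by its registered key would presumably route through the same target-extraction and compute path as the built-in losses. Keeping `split_op` as a deferred import, as the last patch does for the existing losses, keeps the loss-config module importable without the distributed stack, and `extract_targets_from_global_kwargs` keeps the head agnostic about which kwargs each loss needs.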