diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 41a2fe7f..5e978ac2 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -65,7 +65,6 @@ def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleTy def _load_config(self) -> SampledDatasetConfig[SampleType]: assert self.path.is_file(), f"File {self.path} does not exist." config = yaml.safe_load(self.path.open("r")) - Assert.eq(config.keys(), {"config", "metadata"}) if config.keys() == {"config", "metadata"}: # Newer format with metadata config = config["config"] diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 4cfc3b61..511c2d9f 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -100,11 +100,6 @@ class CrossEntropyImpl(str, enum.Enum): triton = "triton" -class DistillationLossImpl(str, enum.Enum): - reverse_kl = "reverse_kl" - cross_entropy = "cross_entropy" - - class TargetFormat(enum.StrEnum): labels = "labels" logits = "logits" diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 8c9ea939..03f7a88e 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -85,6 +85,7 @@ def _fused_cross_entropy_forward_backward( target_format: TargetFormat, group: ProcessGroup | None = None, teacher_softmax_temperature: float = 1.0, + return_target_entropy: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile. @@ -158,6 +159,16 @@ def _fused_cross_entropy_forward_backward( loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: all_reduce(loss, op=ReduceOp.AVG, group=group) + if return_target_entropy and target_format == TargetFormat.logits: + # Compute teacher entropy + teacher_log_prob = torch.log(target + 1e-20) + target_entropy = -(target * teacher_log_prob).sum(dim=-1) + if loss_mask is not None: + target_entropy = target_entropy * loss_mask.squeeze(-1) + target_entropy = target_entropy.mean() + if group is not None: + all_reduce(target_entropy, op=ReduceOp.SUM, group=group) + return loss, grad, target_entropy return loss, grad @@ -236,7 +247,6 @@ def _reverse_kl_forward_backward( group: ProcessGroup | None = None, logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, - **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Reverse KL using PyTorch's native kl_div function. @@ -359,3 +369,53 @@ def reverse_kl_forward_backward( group=group, ) return distillation_loss, distillation_grad + + +def forward_kl_forward_backward( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + group: ProcessGroup | None = None, + logits_scale_factor: float = 1.0, + teacher_softmax_temperature: float = 1.0, + target_format: TargetFormat = TargetFormat.labels, + sequence_parallel_logits: bool = False, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Compute forward KL divergence: KL(p||q) where p is the target distribution (teacher) and q is the predicted (student). + This is mode-covering (vs. 
mode-seeking for reverse KL) and useful for: + - Encouraging the model to cover all modes of the target distribution + - Spreading probability mass broadly across the target support + - Standard distillation scenarios where you want to match the full teacher distribution + + Key differences from reverse KL: + - Forward KL: KL(p||q) = mode-covering (spreads mass broadly) + - Reverse KL: KL(q||p) = mode-seeking (focuses on target modes) + + Takes: + logits: [BxS, V] or [B, S, V], where V is local vocab size + target: [BxS, V] or [B, S, V] (logits format) + loss_mask: [BxS] or [B, S] or None + ... + + Returns: + loss: Forward KL divergence loss + grad: Gradients w.r.t. logits + """ + assert target_format == TargetFormat.logits, "Forward KL only supports logits format" + Assert.eq(target.shape, logits.shape) + distillation_loss, distillation_grad, teacher_entropy = _fused_cross_entropy_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + logits_scale_factor=logits_scale_factor, + target_format=target_format, + group=group, + teacher_softmax_temperature=teacher_softmax_temperature, + return_target_entropy=True, + ) + distillation_loss -= teacher_entropy + + return distillation_loss, distillation_grad diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 53dac289..9f6cbf4c 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -5,11 +5,11 @@ from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl -from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig +from fast_llm.layers.block.config import BlockConfig, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.language_model.lm_head_losses import LanguageModelLossConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -19,21 +19,6 @@ from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction -class LanguageModelKwargs(BlockKwargs): - token_ids = "token_ids" - position_ids = "position_ids" - token_map = "token_map" - sample_map = "sample_map" - embedding_map = "embedding_map" - # TODO: These are generic - labels = "labels" - phase = "phase" - chosen_spans = "chosen_spans" - rejected_spans = "rejected_spans" - loss_mask = "loss_mask" - mask_inputs = "mask_inputs" - - @config_class() class LanguageModelEmbeddingsConfig(BlockConfig): _abstract = False @@ -135,44 +120,22 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Configuration for the final normalization layer.", hint=FieldHint.architecture, ) + losses: dict[str, LanguageModelLossConfig] = Field( + default_factory=dict, + desc="A dictionary of loss names and their configurations.", + hint=FieldHint.core, + ) # TODO: Cleanup output_weight: ParameterConfig = Field( desc="Configuration for the LM output layer (weight). 
Ignored for tied embeddings", hint=FieldHint.architecture, ) - cross_entropy_implementation: CrossEntropyImpl = Field( - default=CrossEntropyImpl.auto, - desc="Implementation for the cross-entropy computation.", - hint=FieldHint.performance, - ) - distillation_loss_implementation: DistillationLossImpl = Field( - default=DistillationLossImpl.cross_entropy, - desc="Implementation for the distillation cross-entropy computation.", - hint=FieldHint.performance, - ) cross_entropy_splits: int | None = Field( default=None, desc="Split the logit and cross-entropy computation into this many fragment, to reduce memory usage.", hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) - logit_z_loss: float = Field( - default=0.0, - desc="Regularize the logits with Z-loss.", - doc="We recommend 1e-4 for stability, as used for training PaLM.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - language_model_loss_factor: float = Field( - default=None, - desc="Factor to scale the language modeling loss by when using distillation.", - hint=FieldHint.feature, - ) - distillation_loss_factor: float = Field( - default=1.0, - desc="Factor to scale the distillation loss by when using distillation.", - hint=FieldHint.feature, - ) logits_scale_factor: float = Field( default=1.0, desc="Multiply output logits by scale factor.", @@ -181,10 +144,10 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - teacher_softmax_temperature: float = Field( - default=1.0, - desc="Divides distillation target logits by this factor.", - doc="Divides distillation target logits by this factor.", + logit_z_loss: float = Field( + default=0.0, + desc="Regularize the logits with Z-loss.", + doc="We recommend 1e-4 for stability, as used for training PaLM.", hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) @@ -193,11 +156,6 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Name of the reference model to use for dpo.", hint=FieldHint.feature, ) - dpo_beta: float | None = Field( - default=1.0, - desc="Beta value for DPO loss.", - hint=FieldHint.feature, - ) distillation_model: str | None = Field( default=None, desc="Name of the reference model to use for knowledge distillation." @@ -237,11 +195,19 @@ def layer_class(self) -> "type[LanguageModelHead]": def _validate(self) -> None: with self._set_implicit_default(): - if self.language_model_loss_factor is None: - if self.distillation_model is None: - self.language_model_loss_factor = 1.0 - else: - self.language_model_loss_factor = 0.0 + if not self.losses: + if "losses" not in self._explicit_fields: + self.losses = { + "lm_loss": LanguageModelLossConfig._from_dict( + { + "type": "cross_entropy", + "weight": 1.0, + } + ) + } + for loss_config in self.losses.values(): + if "distillation" in loss_config.type: + assert self.distillation_model is not None, "Distillation loss requires a distillation model." 
super()._validate() assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 93850d24..fda5e338 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -10,7 +10,8 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.layers.block.block import Block from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index b1d0c2ac..cb2312d7 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -13,9 +13,6 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig -from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward -from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockDimNames @@ -25,8 +22,9 @@ LanguageModelEmbeddingsConfig, LanguageModelHeadBaseConfig, LanguageModelHeadConfig, - LanguageModelKwargs, ) +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs +from fast_llm.layers.language_model.lm_head_losses import _format_name from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, div, get_unique @@ -87,16 +85,6 @@ def __init__( if self._config.cross_entropy_splits is not None and self._sequence_parallel: assert not self._vocab_parallel - if not self._config.enable_dpo: - self._cross_entropy_impl = self._config.cross_entropy_implementation - if self._cross_entropy_impl == CrossEntropyImpl.auto: - if self._vocab_parallel: - self._cross_entropy_impl = CrossEntropyImpl.fused - elif TritonConfig.TRITON_ENABLED: - self._cross_entropy_impl = CrossEntropyImpl.triton - else: - self._cross_entropy_impl = CrossEntropyImpl.fused - self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) self.final_norm = self._config.normalization.get_layer( @@ -113,6 +101,13 @@ def __init__( peft=self._peft, ) + self._formatted_loss_names = {} + for loss_name, loss_config in self._config.losses.items(): + if loss_config.weight > 0.0: + self._formatted_loss_names[loss_name] = loss_config.get_formatted_name( + loss_name, self._prediction_distance + ) + def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: @@ -137,8 +132,6 @@ def forward( # TODO: Drop autograd entirely. # TODO: Skip cross-entropy backward if not needed. 
language_model_loss = self._forward(input_, kwargs, losses) - if losses is not None and language_model_loss is not None: - losses[self._loss_name].append(language_model_loss.detach()) # TODO: Return the model output when needed. if self._is_last_head: # Last head should return the loss for backward. @@ -163,6 +156,12 @@ def _forward_backward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None ) -> tuple[torch.Tensor, torch.Tensor | None]: targets = self._get_targets(kwargs) + loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) + if loss_mask is not None: + loss_mask = loss_mask.flatten() + if self._sequence_parallel_logits: + loss_mask = split_op(loss_mask, self._parallel_dim.group, 0) + input_ = input_.detach().requires_grad_(do_grad := targets is not None and self.training) with torch.enable_grad(): ln_output = self.final_norm(input_) @@ -176,7 +175,7 @@ def _forward_backward( output_weights = self.output_weights loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split( - ln_output.detach(), targets, output_weights, grad_output, kwargs, losses + ln_output.detach(), targets, loss_mask, output_weights, grad_output, kwargs, losses ) if do_grad: @@ -185,52 +184,22 @@ def _forward_backward( else: return loss, None - def _get_targets( - self, kwargs: dict - ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | None: - # Loss mask for distillation. (Labels are already masked.) - if self._config.enable_dpo: - dpo_target = kwargs.get(LanguageModelKwargs.labels) - lm_target = None - distillation_target = None - loss_mask = None - else: - dpo_target = None - if self._config.distillation_model is None: - distillation_target, loss_mask = None, None - else: - # Target is reference model logits. - distillation_target = kwargs[f"{self._config.distillation_model}_logits"].flatten(0, -2) - loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) - if loss_mask is not None: - loss_mask = loss_mask.flatten() - - if self._config.distillation_model is None or self._config.language_model_loss_factor > 0.0: - lm_target = kwargs.get(LanguageModelKwargs.labels) - if lm_target is not None: - # MTP: Shift the labels - lm_target_sequence_length = ( - lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - self._prediction_heads - ) - if LanguageModelKwargs.sequence_q_dim in kwargs: - Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) - lm_target_slice = slice( - self._prediction_distance, self._prediction_distance + lm_target_sequence_length - ) - lm_target = ( - lm_target[lm_target_slice] - if kwargs[LanguageModelKwargs.sequence_first] - else lm_target[:, lm_target_slice] - ).flatten() - else: - lm_target = None - - targets = (dpo_target, lm_target, distillation_target, loss_mask) - if self._sequence_parallel_logits: - targets = [None if target is None else split_op(target, self._parallel_dim.group, 0) for target in targets] - if not any(target is not None for target in targets): - # Simplify so we don't have to check every time. 
- targets = None + def _get_targets(self, kwargs: dict) -> dict | None: + targets = {} + for loss_config in self._config.losses.values(): + if loss_config.weight == 0.0: + continue + loss_targets = loss_config.extract_targets_from_global_kwargs( + kwargs, + prediction_distance=self._prediction_distance, + prediction_heads=self._prediction_heads, + head_config=self._config, + sequence_parallel_logits=self._sequence_parallel_logits, + group=self._parallel_dim.group, + ) + targets.update({k: v for k, v in loss_targets.items() if v is not None}) + if len(targets) == 0: + return None return targets def get_output_weights(self) -> list[torch.Tensor]: @@ -239,15 +208,16 @@ def get_output_weights(self) -> list[torch.Tensor]: def _logits_cross_entropy_forward_backward_split( self, input_: torch.Tensor, - targets: tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | None, + targets: dict[str, "torch.Tensor"] | None, + loss_mask: torch.Tensor | None, weight: torch.Tensor, grad_output: float, kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - if self._config.cross_entropy_splits is None or targets is None: - loss, logit_input_grad = self._logits_cross_entropy_forward_backward( - input_, targets, weight, grad_output, kwargs, losses + if self._config.cross_entropy_splits is None: + loss, logit_input_grad = self._logits_loss_forward_backward( + input_, targets, loss_mask, weight, grad_output, kwargs, losses ) if targets is None: # TODO: Make a proper way of returning the model output. @@ -270,18 +240,35 @@ def _logits_cross_entropy_forward_backward_split( logit_input_grad = torch.empty_like(logit_input) else: logit_input_grad = None + + # Collect all tensors that need to be split to determine the split size + tensors_to_check = [logit_input] + if loss_mask is not None: + tensors_to_check.append(loss_mask) + tensors_to_check.extend(target for target in targets.values() if target is not None) + split_size = div( - get_unique(target.size(0) for target in targets if target is not None), + get_unique(tensor.size(0) for tensor in tensors_to_check), self._config.cross_entropy_splits, ) tensors_split = [ [None] * self._config.cross_entropy_splits if tensor is None else tensor.split(split_size) - for tensor in [logit_input, *targets, logit_input_grad] + for tensor in [logit_input, loss_mask, logit_input_grad] ] - for logit_input_, *targets_, logit_input_grad_ in zip(*tensors_split, strict=True): - loss_, grad_ = self._logits_cross_entropy_forward_backward( + target_split = { + name: ( + [None] * self._config.cross_entropy_splits + if targets[name] is None + else targets[name].split(split_size) + ) + for name in targets + } + + for i, (logit_input_, loss_mask_, logit_input_grad_) in enumerate(zip(*tensors_split, strict=True)): + loss_, grad_ = self._logits_loss_forward_backward( logit_input_, - targets_, + {name: target_split[name][i] for name in target_split}, + loss_mask_, weight, grad_output, kwargs, @@ -301,10 +288,11 @@ def _logits_cross_entropy_forward_backward_split( all_reduce(loss, group=self._parallel_dim.group) return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None - def _logits_cross_entropy_forward_backward( + def _logits_loss_forward_backward( self, input_: torch.Tensor, - targets: tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None], + targets: dict[str, "torch.Tensor"] | None, + loss_mask: torch.Tensor | None, weight: torch.Tensor, grad_output: 
float, kwargs: dict, @@ -319,6 +307,7 @@ def _logits_cross_entropy_forward_backward( sequence_parallel=self._sequence_parallel and self._vocab_parallel, ) + # TODO: also move to lm_head_losses? if self._config.logit_z_loss > 0.0: logits = z_loss( logits, @@ -344,92 +333,52 @@ def _logits_cross_entropy_forward_backward( if targets is None: return logits * self._config.logits_scale_factor, None - dpo_target, lm_target, distillation_target, loss_mask = targets - if dpo_target is not None: - dpo_loss, dpo_grad = compute_dpo_loss( + total_loss, grad = None, None + for loss_name, loss_config in self._config.losses.items(): + if loss_config.weight == 0.0: + continue + # losses are returned unscaled but the grads are already scaled + loss_unscaled_, grad_ = loss_config.compute_loss( logits, - dpo_target, - kwargs.get(f"{self._config.dpo_reference_model}_logits"), - kwargs[LanguageModelKwargs.chosen_spans], - kwargs[LanguageModelKwargs.rejected_spans], - self._config.dpo_beta, - grad_output * self._loss_coefficient, - ) - else: - dpo_loss, dpo_grad = None, None - - if lm_target is not None: - lm_loss, lm_grad = cross_entropy_forward_backward( - logits.flatten(0, -2), - lm_target, - None, + loss_mask, + grad_output=( + (grad_output * self._loss_coefficient * loss_config.weight if grad_output is not None else None) + if loss_config.weight != 0.0 + else None + ), group=group, - grad_output=grad_output * self._loss_coefficient * self._config.language_model_loss_factor, - implementation=self._cross_entropy_impl, logits_scale_factor=self._config.logits_scale_factor, - target_format=TargetFormat.labels, + vocab_parallel=self._vocab_parallel, + kwargs={**kwargs, **targets}, ) - lm_loss = lm_loss * self._config.language_model_loss_factor - else: - lm_loss, lm_grad = None, None - - if distillation_target is not None and self._config.distillation_loss_factor > 0.0: - if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: - distillation_loss, distillation_grad = reverse_kl_forward_backward( - logits.flatten(0, -2), - distillation_target, - loss_mask, - grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, - group=group, - logits_scale_factor=self._config.logits_scale_factor, - teacher_softmax_temperature=self._config.teacher_softmax_temperature, - target_format=( - TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits - ), - sequence_parallel_logits=self._sequence_parallel_logits, - ) + loss_ = loss_unscaled_ * loss_config.weight * self._loss_coefficient - elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy: - distillation_loss, distillation_grad = cross_entropy_forward_backward( - logits.flatten(0, -2), - distillation_target, - loss_mask, - group=group, - grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, - implementation=self._cross_entropy_impl, - logits_scale_factor=self._config.logits_scale_factor, - target_format=TargetFormat.logits, - ) - else: - raise ValueError( - f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" - ) - distillation_loss = distillation_loss * self._config.distillation_loss_factor - else: - distillation_loss, distillation_grad = None, None + if losses is not None: + losses[self._formatted_loss_names[loss_name]].append(loss_unscaled_.detach()) - # TODO: de-allocate earlier. 
- del logits + if total_loss is None: + total_loss = loss_ + else: + total_loss = total_loss + loss_ - # TODO: Accumulate grads in-place to reduce memory and compute overhead. - grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) + if grad_ is not None: + if grad is None: + grad = grad_ + else: + grad = grad + grad_ - # TODO: Return individual losses? - loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) - if self.training and losses is not None: - if dpo_loss is not None: - losses[self._dpo_loss_name].append(dpo_loss.detach()) - if self._config.distillation_model is not None and distillation_loss is not None: - losses[self._distillation_loss_name].append(distillation_loss.detach()) - if self._config.distillation_model is not None and lm_loss is not None: - losses[self._distillation_language_model_loss_name].append(lm_loss.detach()) + if losses is not None and total_loss is not None: + losses[self._total_head_loss_name].append(total_loss.detach()) - return loss, output_parallel_linear_backward(grad, context) if self.training else None + return total_loss, output_parallel_linear_backward(grad, context) if self.training else None @functools.cached_property - def _loss_name(self) -> str: - name = "language_model_loss" + def _total_head_loss_name(self) -> str: + """ + Combined total scaled loss used for training. + """ + name = "lm_head_loss" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @@ -441,54 +390,21 @@ def _z_loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name - @functools.cached_property - def _dpo_loss_name(self) -> str: - name = "dpo_loss" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - - @functools.cached_property - def _distillation_language_model_loss_name(self) -> str: - name = "distillation_language_model_loss" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - - @functools.cached_property - def _distillation_loss_name(self) -> str: - name = "distillation_loss" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - loss_defs = [LossDef(name=self._loss_name, formatted_name=_format_name(self._loss_name), count=count)] - if self._config.logit_z_loss: + loss_defs = [ + LossDef( + name=self._total_head_loss_name, formatted_name=_format_name(self._total_head_loss_name), count=count + ) + ] + if self._config.logit_z_loss > 0.0: loss_defs.append( LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) ) - if self._config.enable_dpo: - loss_defs.append( - LossDef(name=self._dpo_loss_name, formatted_name=_format_name(self._dpo_loss_name), count=count) + for loss_name, loss_config in self._config.losses.items(): + loss_def: LossDef = loss_config.get_loss_def( + name=loss_name, count=count, prediction_distance=self._prediction_distance ) - - if self._config.distillation_model is not None: - loss_defs.append( - LossDef( - name=self._distillation_loss_name, - formatted_name=_format_name(self._distillation_loss_name), - count=count, - ) - ) - if self._config.language_model_loss_factor > 0.0: - loss_defs.append( - LossDef( - name=self._distillation_language_model_loss_name, - formatted_name=_format_name(self._distillation_language_model_loss_name), - count=count, - ) - ) + loss_defs.append(loss_def) return loss_defs @@ -496,17 +412,3 @@ def get_loss_definitions(self, 
count: int = 1) -> list[LossDef]:
     def heads(self):
         # For compatibility with MTP.
         return [self]
-
-
-def _format_name(name: str) -> str:
-    return name.replace("_", " ")
-
-
-def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor:
-    tensors = [tensor for tensor in tensors if tensor is not None]
-    if len(tensors) > 1:
-        return sum(tensors)
-    elif len(tensors) == 1:
-        return tensors[0]
-    else:
-        raise RuntimeError()
diff --git a/fast_llm/layers/language_model/kwargs.py b/fast_llm/layers/language_model/kwargs.py
new file mode 100644
index 00000000..4f620388
--- /dev/null
+++ b/fast_llm/layers/language_model/kwargs.py
@@ -0,0 +1,23 @@
+from fast_llm.layers.block.config import BlockKwargs
+
+
+class TargetsKwargs:
+    lm_target = "preprocessed_lm_target"
+    dpo_target = "preprocessed_dpo_target"
+    reference_model_logits = "reference_model_logits"
+    dpo_reference_model_logits = "dpo_reference_model_logits"
+
+
+class LanguageModelKwargs(BlockKwargs):
+    token_ids = "token_ids"
+    position_ids = "position_ids"
+    token_map = "token_map"
+    sample_map = "sample_map"
+    embedding_map = "embedding_map"
+    # TODO: These are generic
+    labels = "labels"
+    phase = "phase"
+    chosen_spans = "chosen_spans"
+    rejected_spans = "rejected_spans"
+    loss_mask = "loss_mask"
+    mask_inputs = "mask_inputs"
diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py
new file mode 100644
index 00000000..f6e69b4f
--- /dev/null
+++ b/fast_llm/layers/language_model/lm_head_losses.py
@@ -0,0 +1,341 @@
+import abc
+import logging
+import typing
+
+from fast_llm.config import Config, Field, FieldHint, check_field, config_class
+from fast_llm.engine.base_model.config import LossDef
+from fast_llm.engine.config_utils.data_type import DataType
+from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig
+from fast_llm.layers.language_model.kwargs import LanguageModelKwargs, TargetsKwargs
+from fast_llm.utils import Assert
+
+if typing.TYPE_CHECKING:
+    import torch
+
+    from fast_llm.core.distributed import ProcessGroup
+    from fast_llm.layers.language_model.config import LanguageModelHeadConfig
+
+logger = logging.getLogger(__name__)
+
+#
+# CE loss on lm_targets for standard LM training. Here targets are already masked.
+# CE loss for distillation: cross-entropy that uses reference_model_logits as soft targets; not implemented yet (TODO).
+# Forward KL divergence loss on reference_model_logits for distillation (mode-covering).
+# Reverse KL divergence loss on reference_model_logits for distillation (mode-seeking).
+# DPO loss for alignment using chosen and rejected spans.
+#
+
+
+def _format_name(name: str) -> str:
+    return name.replace("_", " ")
+
+
+@config_class(registry=True)
+class LanguageModelLossConfig(Config):
+    """
+    Losses can register themselves using @config_class(dynamic_type={LanguageModelLossConfig: "loss_type_name"}).
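+
+    Illustrative sketch of registering a custom loss type (`MyLossConfig` and "my_loss" are
+    placeholder names, not part of this module):
+
+        @config_class(dynamic_type={LanguageModelLossConfig: "my_loss"})
+        class MyLossConfig(LanguageModelLossConfig):
+            _name: typing.ClassVar[str] = "MyLoss"
+            _abstract: typing.ClassVar[bool] = False
+            # Implement compute_loss() and extract_targets_from_global_kwargs().
+
+    A registered type is then selected by name in the head config, e.g.
+    `losses: {"my_loss_entry": {"type": "my_loss", "weight": 1.0}}`.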
+ """ + + _name: typing.ClassVar[str] + _abstract: typing.ClassVar[bool] = True + + weight: float = Field( + default=1.0, + hint=FieldHint.core, + desc="Weight for this loss in the total loss computation.", + valid=check_field(Assert.geq, 0.0), + ) + + @abc.abstractmethod + def compute_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + group: "ProcessGroup" = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + pass + + def get_loss_def(self, name: str, count: int = 1, prediction_distance: int | None = None) -> LossDef: + name = self.get_formatted_name(name, prediction_distance) + return LossDef( + name=name, + formatted_name=_format_name(name), + count=count, + dtype=DataType.float32, + ) + + def _validate(self): + Assert.geq(self.weight, 0.0) + super()._validate() + + def get_formatted_name(self, name=None, prediction_distance: int | None = None) -> str: + name = f"{self._name}({name})" + if prediction_distance is not None: + name = f"{name}_{prediction_distance}" + return name + + @abc.abstractmethod + def extract_targets_from_global_kwargs( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + head_config: "LanguageModelHeadConfig | None" = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + pass + + +@config_class(dynamic_type={LanguageModelLossConfig: "cross_entropy"}) +class CrossEntropyLMLossConfig(LanguageModelLossConfig): + _name: typing.ClassVar[str] = "CE" + _abstract: typing.ClassVar[bool] = False + + implementation: CrossEntropyImpl = Field( + default=CrossEntropyImpl.auto, + desc="Implementation for the cross-entropy computation.", + hint=FieldHint.performance, + ) + + teacher_softmax_temperature: float = Field( + default=1.0, + hint=FieldHint.optional, + desc="Temperature for teacher softmax (used in distillation losses).", + valid=check_field(Assert.gt, 0.0), + ) + + def extract_targets_from_global_kwargs( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + head_config: "LanguageModelHeadConfig | None" = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + if kwargs is None: + kwargs = {} + + lm_target = kwargs.get(LanguageModelKwargs.labels) + if lm_target is not None: + # MTP: Shift the labels + lm_target_sequence_length = ( + lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - prediction_heads + ) + if LanguageModelKwargs.sequence_q_dim in kwargs: + Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) + lm_target_slice = slice(prediction_distance, prediction_distance + lm_target_sequence_length) + lm_target = ( + lm_target[lm_target_slice] + if kwargs[LanguageModelKwargs.sequence_first] + else lm_target[:, lm_target_slice] + ).flatten() + if sequence_parallel_logits: + from fast_llm.core.ops import split_op + + lm_target = split_op(lm_target, group, 0) + return {TargetsKwargs.lm_target: lm_target} + + def compute_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + group: "ProcessGroup" = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, + ) -> 
"tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.functional.cross_entropy import cross_entropy_forward_backward + + target = kwargs.get(TargetsKwargs.lm_target) + implementation = self.implementation + if implementation == CrossEntropyImpl.auto: + if vocab_parallel: + implementation = CrossEntropyImpl.fused + elif TritonConfig.TRITON_ENABLED: + implementation = CrossEntropyImpl.triton + else: + implementation = CrossEntropyImpl.fused + + return cross_entropy_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=None, # Labels are already masked + grad_output=grad_output, + group=group, + implementation=implementation, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=self.teacher_softmax_temperature, + target_format=TargetFormat.labels, + ) + + +@config_class(dynamic_type={LanguageModelLossConfig: "forward_kl_distillation"}) +class ForwardKLLossConfig(LanguageModelLossConfig): + """Forward KL divergence KL(p||q) for distillation (mode-covering).""" + + _name: typing.ClassVar[str] = "FwdKL" + _abstract: typing.ClassVar[bool] = False + + teacher_softmax_temperature: float = Field( + default=1.0, + hint=FieldHint.optional, + desc="Temperature for teacher softmax.", + valid=check_field(Assert.gt, 0.0), + ) + + def extract_targets_from_global_kwargs( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + head_config: "LanguageModelHeadConfig | None" = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + if kwargs is None: + kwargs = {} + + reference_model_logits = kwargs.get(f"{head_config.distillation_model}_logits") + if reference_model_logits is not None: + reference_model_logits = reference_model_logits.flatten(0, -2) + if sequence_parallel_logits: + from fast_llm.core.ops import split_op + + reference_model_logits = split_op(reference_model_logits, group, 0) + return {TargetsKwargs.reference_model_logits: reference_model_logits} + + def compute_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + group: "ProcessGroup" = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.functional.cross_entropy import forward_kl_forward_backward + + target = kwargs.get(TargetsKwargs.reference_model_logits) + + return forward_kl_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=self.teacher_softmax_temperature, + target_format=TargetFormat.logits, + ) + + +@config_class(dynamic_type={LanguageModelLossConfig: "reverse_kl_distillation"}) +class ReverseKLLossConfig(ForwardKLLossConfig): + """Reverse KL divergence KL(q||p) for distillation (mode-seeking).""" + + _name: typing.ClassVar[str] = "RevKL" + _abstract: typing.ClassVar[bool] = False + + def compute_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + group: "ProcessGroup" = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.functional.cross_entropy import reverse_kl_forward_backward + + # Use distillation_target for KL losses + target = 
kwargs.get(TargetsKwargs.reference_model_logits) + + return reverse_kl_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=self.teacher_softmax_temperature, + target_format=TargetFormat.logits, + ) + + +@config_class(dynamic_type={LanguageModelLossConfig: "dpo"}) +class DPOLossConfig(LanguageModelLossConfig): + """Direct Preference Optimization (DPO) loss for alignment.""" + + _name: typing.ClassVar[str] = "DPO" + _abstract: typing.ClassVar[bool] = False + + beta: float = Field( + default=1.0, + hint=FieldHint.core, + desc="Beta parameter for DPO loss (controls strength of preference optimization).", + valid=check_field(Assert.gt, 0.0), + ) + + def extract_targets_from_global_kwargs( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + head_config: "LanguageModelHeadConfig | None" = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + if kwargs is None: + kwargs = {} + + reference_model_logits = kwargs.get(f"{head_config.dpo_reference_model}_logits") + dpo_target = kwargs.get(LanguageModelKwargs.labels) + if reference_model_logits is not None or dpo_target is not None: + from fast_llm.core.ops import split_op + + if reference_model_logits is not None: + reference_model_logits = reference_model_logits.flatten(0, -2) + if sequence_parallel_logits: + reference_model_logits = split_op(reference_model_logits, group, 0) + if dpo_target is not None: + dpo_target = split_op(dpo_target, group, 0) + return { + TargetsKwargs.dpo_reference_model_logits: reference_model_logits, + TargetsKwargs.dpo_target: dpo_target, + } + + def compute_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + group: "ProcessGroup" = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.functional.dpo import compute_dpo_loss + + dpo_target = kwargs.get(TargetsKwargs.dpo_target) + dpo_reference_model_logits = kwargs.get(TargetsKwargs.dpo_reference_model_logits) + chosen_spans = kwargs.get(LanguageModelKwargs.chosen_spans) + rejected_spans = kwargs.get(LanguageModelKwargs.rejected_spans) + + return compute_dpo_loss( + logits=logits, + targets=dpo_target, + reference_model_logits=dpo_reference_model_logits, + chosen_spans=chosen_spans, + rejected_spans=rejected_spans, + beta=self.beta, + grad_output=grad_output, + ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 2f43d1e4..846c6564 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -12,7 +12,7 @@ from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockDimNames, BlockKwargs -from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron diff --git a/fast_llm/models/multimodal/model.py b/fast_llm/models/multimodal/model.py index 
890d5760..88da79e6 100644 --- a/fast_llm/models/multimodal/model.py +++ b/fast_llm/models/multimodal/model.py @@ -10,7 +10,7 @@ from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockDimNames, BlockKwargs -from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs from fast_llm.layers.vision.config import VisionKwargs from fast_llm.layers.vision.vision_encoder import VisionMultiModalModel from fast_llm.models.gpt.config import GPTBatchConfig diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index 72644d06..716c56ba 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -8,7 +8,11 @@ import torch from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig -from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward +from fast_llm.functional.cross_entropy import ( + cross_entropy_forward_backward, + forward_kl_forward_backward, + reverse_kl_forward_backward, +) from fast_llm.utils import Assert from tests.utils.utils import requires_cuda @@ -127,6 +131,41 @@ def test_reverse_kl(loss_masking, target_format): _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref, 1e-3) +def _forward_kl_forward_backward_torch(logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None): + # Manual reference: sum over vocab then average over all tokens (not just valid ones). + # Forward KL: KL(p||q) where p=teacher, q=student + logits = logits.detach().requires_grad_(True) + per_sample = torch.nn.functional.kl_div( + torch.log_softmax(logits.float(), dim=-1), + torch.log_softmax(target.float(), dim=-1), + reduction="none", + log_target=True, + ).sum(dim=-1) + if loss_mask is not None: + per_sample = per_sample * loss_mask + output = per_sample.sum() / per_sample.numel() + output.backward() + return output, logits.grad + + +@requires_cuda +@pytest.mark.slow +# TODO: Support the same parameterization as above in the reference implementation. 
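+# Note: the reference above evaluates KL(p||q) directly via torch.nn.functional.kl_div, while
+# forward_kl_forward_backward goes through the fused cross-entropy path and subtracts the teacher
+# entropy, using KL(p||q) = CE(p, q) - H(p) with p = softmax(target) and q = softmax(logits).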
+@pytest.mark.parametrize("loss_masking", [False, True]) +@pytest.mark.parametrize("target_format", (TargetFormat.logits,)) +def test_forward_kl(loss_masking, target_format): + logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) + out_ref, grad_ref = _forward_kl_forward_backward_torch(logits, target, loss_mask) + out, grad = forward_kl_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=1.0, + target_format=TargetFormat.logits, + ) + _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref, 1e-3) + + def _mp_worker(rank: int, world_size: int, init_method: str, fn_args: tuple): try: torch.distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size, init_method=init_method) @@ -189,7 +228,7 @@ def _compare_parallel_cross_entropy( def compare_parallel_cross_entropy(rank: int, group: torch.distributed.ProcessGroup): success = True - for function in (reverse_kl_forward_backward, cross_entropy_forward_backward): + for function in (reverse_kl_forward_backward, forward_kl_forward_backward, cross_entropy_forward_backward): for target_format in (TargetFormat.logits,): for loss_masking in [False, True]: try: diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 623a30d8..ed639db9 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -5,10 +5,12 @@ from fast_llm.config import UpdateType from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl +from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelHeadConfig from fast_llm.layers.language_model.head import LanguageModelHead +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs +from fast_llm.layers.language_model.lm_head_losses import LanguageModelLossConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -43,6 +45,20 @@ def _reverse_kl_loss( return loss +def _kl_loss( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + teacher_softmax_temperature: float = 1.0, +): + return _reverse_kl_loss( + target, + logits, + loss_mask, + teacher_softmax_temperature, + ) + + def _lm_head( input_: torch.Tensor, target: torch.Tensor, @@ -54,7 +70,7 @@ def _lm_head( grad_output: float = 1.0, logit_scale_factor: float = 1.0, logit_z_loss=0.0, - distillation_loss_implementation: DistillationLossImpl = DistillationLossImpl.cross_entropy, + losses: dict[str, LanguageModelLossConfig], ): hidden = torch.rms_norm( input_.to(rms_weight.dtype), @@ -64,28 +80,34 @@ def _lm_head( ) logits = torch.nn.functional.linear(hidden, logit_weight).float() - if distillation_loss_implementation == DistillationLossImpl.reverse_kl: - Assert.eq(logits.shape, target.shape) - loss = _reverse_kl_loss( - (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask - ) - loss.backward(torch.full_like(loss, grad_output)) - return loss, None + if "dist_loss" in losses: + if losses["dist_loss"].type == "reverse_kl_distillation": + Assert.eq(logits.shape, target.shape) + loss = _reverse_kl_loss( + (logits * 
logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask + ) + # Apply distillation_loss_factor to grad_output for backward + loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].weight)) + # Return scaled loss + return loss * losses["dist_loss"].weight, None + elif losses["dist_loss"].type == "forward_kl_distillation": + Assert.eq(logits.shape, target.shape) + loss = _kl_loss( + (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask + ) + # Apply distillation_loss_factor to grad_output for backward + loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].weight)) + # Return scaled loss + return loss * losses["dist_loss"].weight, None if logit_scale_factor != 1.0: logits *= logit_scale_factor z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) if logit_z_loss > 0 else None - if target.ndim == logits.ndim: - loss = torch.nn.functional.cross_entropy( - logits.flatten(0, -2), target.float().softmax(-1).flatten(0, -2), reduction="none" - ) - if loss_mask is not None: - loss = loss * loss_mask.flatten() - loss = loss.mean() - else: - loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) - loss.backward(torch.full_like(loss, grad_output)) - return loss, z_loss + # Language model loss (cross-entropy with hard labels) + loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) + # Apply language_model_loss_factor + loss.backward(torch.full_like(loss, grad_output * losses["lm_loss"].weight)) + return loss * losses["lm_loss"].weight, z_loss SEQUENCE_LENGTH = 200 @@ -109,50 +131,96 @@ def _lm_head( ({"tied_embedding_weight": True}, {}, False, 1), ({}, {}, False, 2), ({}, {}, True, 1), - ( + # Skip CE distillation for now - not yet implemented in new losses system + # ( + # { + # "head": { + # "distillation_model": "distillation", + # "losses": { + # "lm_loss": { + # "type": "cross_entropy", + # "weight_scalor": 0.0, + # }, + # "dist_loss": { + # "type": "cross_entropy_dist", # TODO: Not implemented yet + # "weight_scalor": 1.0, + # } + # } + # } + # }, + # {}, + # False, + # 1, + # ), + pytest.param( { "head": { "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.cross_entropy, + "losses": { + "lm_loss": { + "type": "cross_entropy", + "weight": 0.0, + }, + "dist_loss": { + "type": "reverse_kl_distillation", + "weight": 1.0, + }, + }, } }, {}, False, 1, + id="track_lm_zero_factor", ), - ( + pytest.param( { "head": { "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.reverse_kl, + "losses": { + "lm_loss": { + "type": "cross_entropy", + "weight": 0.0, + }, + "dist_loss": { + "type": "reverse_kl_distillation", + "weight": 0.0, + }, + }, } }, {}, False, 1, + marks=pytest.mark.xfail( + reason="Cannot track both losses with zero factor", + strict=True, + ), + id="track_both_zero_factors", ), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.cross_entropy, - "language_model_loss_factor": 1.0, - } - }, - {}, - True, - 1, - ), - ( + pytest.param( { "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.reverse_kl, + "losses": { + "lm_loss": { + "type": "cross_entropy", + "weight": 1.0, + }, + "dist_loss": { + "type": "reverse_kl_distillation", + "weight": 1.0, + }, + }, } }, {}, - True, + False, 1, + marks=pytest.mark.xfail( + 
reason="Cannot track distillation loss without distillation model being set", + strict=True, + ), + id="track_distillation_without_model", ), ), ) @@ -164,8 +232,14 @@ def test_lm_head( prediction_heads: int, ): head_config = { - "cross_entropy_implementation": cross_entropy_impl, "normalization": {"type": "rms_norm"}, + "losses": { + "lm_loss": { + "type": "cross_entropy", + "implementation": cross_entropy_impl, + "weight": 1.0, + } + }, } config = GPTBaseModelConfig.from_dict( { @@ -222,19 +296,19 @@ def test_lm_head( AttentionKwargs.sequence_first: sequence_first, AttentionKwargs.grad_output: 1.0, } - if head_config.distillation_model is None: - target = torch.randint( - 0, - VOCAB_SIZE, - label_shape, - dtype=torch.int64, - device=distributed.device, - ) - if loss_mask is not None: - target *= loss_mask + # always set lm targets + target = torch.randint( + 0, + VOCAB_SIZE, + label_shape, + dtype=torch.int64, + device=distributed.device, + ) + if loss_mask is not None: + target *= loss_mask - kwargs[LanguageModelKwargs.labels] = target - else: + kwargs[LanguageModelKwargs.labels] = target + if head_config.distillation_model is not None: assert config.head.max_prediction_distance == 1 target = torch.randn( input_.shape[:-1] + (VOCAB_SIZE,), @@ -291,7 +365,7 @@ def test_lm_head( logit_weight=ref_logit_weight, logit_scale_factor=head_config.logits_scale_factor, logit_z_loss=head_config.logit_z_loss, - distillation_loss_implementation=head_config.distillation_loss_implementation, + losses=head_config.losses, ) # Prepare LM head inputs @@ -303,20 +377,22 @@ def test_lm_head( head_input = torch.stack((shared_hidden, input_.detach())).requires_grad_() output_grad = torch.randn_like(shared_hidden) - loss_name = f"language_model_loss_{prediction_distance}" if prediction_distance > 0 else "language_model_loss" - loss_keys = {loss_name} + lm_head_loss_name = f"lm_head_loss_{prediction_distance}" if prediction_distance > 0 else "lm_head_loss" + expected_loss_keys = {lm_head_loss_name} + + # Get expected loss names from the loss configs + for loss_name, loss_config in head._config.losses.items(): + formatted_name = loss_config.get_formatted_name(loss_name, prediction_distance) + expected_loss_keys.add(formatted_name) + if ref_z_loss is not None: - loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") - if head_config.distillation_model is not None: - loss_keys.add("distillation_loss") - if head_config.language_model_loss_factor > 0: - loss_keys.add("distillation_language_model_loss") + expected_loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") Assert.eq( {loss_definition.name: loss_definition.count for loss_definition in head.get_loss_definitions()}, - {loss_key: 1 for loss_key in loss_keys}, + {loss_key: 1 for loss_key in expected_loss_keys}, ) - losses = {key: [] for key in loss_keys} + losses = {key: [] for key in expected_loss_keys} output, context = stage.forward(head_input, kwargs, losses) stage.backward(output_grad, context) @@ -325,16 +401,16 @@ def test_lm_head( 1e-5 if distributed.config.compute_dtype == DataType.float32 else 1e-4 ) * head_config.logits_scale_factor - Assert.eq(losses.keys(), loss_keys) - Assert.eq(len(losses[loss_name]), 1) + Assert.eq(losses.keys(), expected_loss_keys) + Assert.eq(len(losses[lm_head_loss_name]), 1) if ref_z_loss is not None: Assert.eq(len(losses["z_loss"]), 1) Assert.rms_close_relative(losses["z_loss"][0], ref_z_loss, threshold, min_threshold) - 
Assert.rms_close_relative(losses[loss_name][0], ref_loss, threshold, min_threshold) + Assert.rms_close_relative(losses[lm_head_loss_name][0], ref_loss, threshold, min_threshold) if head._is_last_head: - Assert.all_equal(output, losses[loss_name][0]) + Assert.all_equal(output, losses[lm_head_loss_name][0]) input_grad = head_input.grad else: Assert.all_equal(output, shared_hidden) diff --git a/tests/test_config.py b/tests/test_config.py index 4020b6fb..3c6a76a3 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -148,12 +148,15 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "num_blocks": 12, }, + "head": {"losses": {"lm_loss": {"type": "cross_entropy", "weight": 1.0}}}, "hidden_size": 512, "tied_embedding_weight": False, "peft": {"freeze_others": False}, } else: expected_config["base_model"] = base_model_update + # added by default + expected_config["base_model"]["head"] = {"losses": {"lm_loss": {"type": "cross_entropy", "weight": 1.0}}} check_equal_nested(_trim_type(serialized_config), _trim_type(expected_config)) @@ -296,3 +299,7 @@ def test_distributed_global_ranks(bdp: int, sdp: int, tp: int, pp: int, pipeline Assert.eq(len({global_rank for global_ranks in global_ranks_set for global_rank in global_ranks}), world_size) Assert.eq(len(rank_breakdowns), world_size) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 1248a111..f3d4659c 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -244,7 +244,12 @@ def _update_and_add_testing_config( }, "num_blocks": 2, }, - "head": {"output_weight": init_1}, + "head": { + "output_weight": init_1, + "losses": { + "lm_loss": {"type": "cross_entropy", "weight": 1.0}, + }, + }, "hidden_size": 256, "tied_embedding_weight": True, }, @@ -557,6 +562,12 @@ def _update_and_add_testing_config( "mistral_distill_logits", updates={ ("model", "base_model", "head", "distillation_model"): "teacher", + ("model", "base_model", "head", "losses"): { + "distillation_loss": { + "type": "reverse_kl_distillation", + "factor": 1.0, + }, + }, ("batch", "use_loss_masking_spans"): True, ("reference_models"): { "teacher": { @@ -579,32 +590,11 @@ def _update_and_add_testing_config( skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2"), ) -_update_and_add_testing_config( - "mistral_distill_logits", - "mistral_reverse_kl", - updates={ - ("model", "base_model", "head", "distillation_loss_implementation"): "reverse_kl", - }, - megatron_args=None, - checkpoint_format=MistralCheckpointFormat, - groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.normal, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.unimportant, - ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, - ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, - ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.broken, # failing: fp16, tp2, stp2, stp2_ce4 - }, - compare_factor=2, - # Modes not supported with reference models - skip_tests=("sdp", "ms", "pp"), -) - _update_and_add_testing_config( "mistral_distill_logits", "mistral_distill_activations", updates={ - ("model", "base_model", "head", "distillation_loss_factor"): 0.001, + ("model", "base_model", "head", "losses", "distillation_loss", "factor"): 0.001, ("model", "base_model", "decoder", "block", "distillation_model"): "teacher", ("model", "base_model", "decoder", "block", "activation_distillation_factor"): 
0.1, ("reference_models"): {