Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
c335f6e
train with only layer distillation losses
oleksost Dec 16, 2025
e06a4b2
unscaled loss llogging + training with distillation loss factor = 0
oleksost Dec 16, 2025
179ae25
make logging more explicit
oleksost Dec 17, 2025
af456f0
Merge remote-tracking branch 'origin/main' into train_only_layer_losses
oleksost Dec 17, 2025
9968aac
clean + tests
oleksost Dec 17, 2025
945c5a7
nvm
oleksost Dec 17, 2025
4b6e3d7
forward KL
oleksost Dec 19, 2025
c5fefa0
test forward kl
oleksost Dec 19, 2025
4119596
wip: report unscaled + kl loss
oleksost Dec 19, 2025
b55a0a4
loss config
oleksost Dec 22, 2025
097baeb
wip
oleksost Dec 22, 2025
d773d98
tests
oleksost Dec 22, 2025
35400c1
Merge remote-tracking branch 'origin/main' into train_only_layer_losses
oleksost Dec 22, 2025
282925c
test
oleksost Dec 22, 2025
0f73ea2
tests
oleksost Dec 22, 2025
04a0193
Merge branch 'main' into train_only_layer_losses
oleksost Dec 22, 2025
fa85c41
wip
oleksost Dec 22, 2025
feb416e
Merge branch 'train_only_layer_losses' of https://github.com/ServiceN…
oleksost Dec 22, 2025
31cfb84
wip
oleksost Dec 23, 2025
24fe67b
no grad if factor 0
oleksost Dec 23, 2025
00f6118
Merge remote-tracking branch 'origin/main' into train_only_layer_losses
oleksost Dec 23, 2025
0cadf98
Merge branch 'main' into train_only_layer_losses
oleksost Dec 23, 2025
0e562e9
addressed comments
oleksost Dec 23, 2025
2a474e2
Merge branch 'train_only_layer_losses' of https://github.com/ServiceN…
oleksost Dec 23, 2025
52c1c11
addressed comments
oleksost Dec 23, 2025
406d0a2
Removed Targets class
oleksost Dec 30, 2025
f25380a
fixes
oleksost Dec 30, 2025
8adb7dd
imports
oleksost Dec 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion fast_llm/data/dataset/gpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleTy
def _load_config(self) -> SampledDatasetConfig[SampleType]:
assert self.path.is_file(), f"File {self.path} does not exist."
config = yaml.safe_load(self.path.open("r"))
Assert.eq(config.keys(), {"config", "metadata"})
if config.keys() == {"config", "metadata"}:
# Newer format with metadata
config = config["config"]
Expand Down
5 changes: 0 additions & 5 deletions fast_llm/functional/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,6 @@ class CrossEntropyImpl(str, enum.Enum):
triton = "triton"


class DistillationLossImpl(str, enum.Enum):
reverse_kl = "reverse_kl"
cross_entropy = "cross_entropy"


class TargetFormat(enum.StrEnum):
labels = "labels"
logits = "logits"
Expand Down
62 changes: 61 additions & 1 deletion fast_llm/functional/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def _fused_cross_entropy_forward_backward(
target_format: TargetFormat,
group: ProcessGroup | None = None,
teacher_softmax_temperature: float = 1.0,
return_target_entropy: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
A fused implementation of cross-entropy with torch compile.
Expand Down Expand Up @@ -158,6 +159,16 @@ def _fused_cross_entropy_forward_backward(
loss = per_sample_loss.mean()
if target_format != TargetFormat.labels and group is not None:
all_reduce(loss, op=ReduceOp.AVG, group=group)
if return_target_entropy and target_format == TargetFormat.logits:
# Compute teacher entropy
teacher_log_prob = torch.log(target + 1e-20)
target_entropy = -(target * teacher_log_prob).sum(dim=-1)
if loss_mask is not None:
target_entropy = target_entropy * loss_mask.squeeze(-1)
target_entropy = target_entropy.mean()
if group is not None:
all_reduce(target_entropy, op=ReduceOp.SUM, group=group)
return loss, grad, target_entropy

return loss, grad

Expand Down Expand Up @@ -236,7 +247,6 @@ def _reverse_kl_forward_backward(
group: ProcessGroup | None = None,
logits_scale_factor: float = 1.0,
teacher_softmax_temperature: float = 1.0,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Reverse KL using PyTorch's native kl_div function.
Expand Down Expand Up @@ -359,3 +369,53 @@ def reverse_kl_forward_backward(
group=group,
)
return distillation_loss, distillation_grad


def forward_kl_forward_backward(
logits: torch.Tensor,
target: torch.Tensor,
loss_mask: torch.Tensor | None,
grad_output: float | None,
group: ProcessGroup | None = None,
logits_scale_factor: float = 1.0,
teacher_softmax_temperature: float = 1.0,
target_format: TargetFormat = TargetFormat.labels,
sequence_parallel_logits: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Compute forward KL divergence: KL(p||q) where p is the target distribution (teacher) and q is the predicted (student).
This is mode-covering (vs. mode-seeking for reverse KL) and useful for:
- Encouraging the model to cover all modes of the target distribution
- Spreading probability mass broadly across the target support
- Standard distillation scenarios where you want to match the full teacher distribution

Key differences from reverse KL:
- Forward KL: KL(p||q) = mode-covering (spreads mass broadly)
- Reverse KL: KL(q||p) = mode-seeking (focuses on target modes)

Takes:
logits: [BxS, V] or [B, S, V], where V is local vocab size
target: [BxS, V] or [B, S, V] (logits format)
loss_mask: [BxS] or [B, S] or None
...

Returns:
loss: Forward KL divergence loss
grad: Gradients w.r.t. logits
"""
assert target_format == TargetFormat.logits, "Forward KL only supports logits format"
Assert.eq(target.shape, logits.shape)
distillation_loss, distillation_grad, teacher_entropy = _fused_cross_entropy_forward_backward(
logits=logits,
target=target,
loss_mask=loss_mask,
grad_output=grad_output,
logits_scale_factor=logits_scale_factor,
target_format=target_format,
group=group,
teacher_softmax_temperature=teacher_softmax_temperature,
return_target_entropy=True,
)
distillation_loss -= teacher_entropy

return distillation_loss, distillation_grad
82 changes: 24 additions & 58 deletions fast_llm/layers/language_model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl
from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig
from fast_llm.layers.block.config import BlockConfig, BlockSequenceConfig
from fast_llm.layers.common.normalization.config import NormalizationConfig
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.decoder.config import DecoderBlockConfig
from fast_llm.layers.language_model.lm_head_losses import LanguageModelLossConfig
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
Expand All @@ -19,21 +19,6 @@
from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction


class LanguageModelKwargs(BlockKwargs):
token_ids = "token_ids"
position_ids = "position_ids"
token_map = "token_map"
sample_map = "sample_map"
embedding_map = "embedding_map"
# TODO: These are generic
labels = "labels"
phase = "phase"
chosen_spans = "chosen_spans"
rejected_spans = "rejected_spans"
loss_mask = "loss_mask"
mask_inputs = "mask_inputs"


@config_class()
class LanguageModelEmbeddingsConfig(BlockConfig):
_abstract = False
Expand Down Expand Up @@ -135,44 +120,22 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig):
desc="Configuration for the final normalization layer.",
hint=FieldHint.architecture,
)
losses: dict[str, LanguageModelLossConfig] = Field(
default_factory=dict,
desc="A dictionary of loss names and their configurations.",
hint=FieldHint.core,
)
# TODO: Cleanup
output_weight: ParameterConfig = Field(
desc="Configuration for the LM output layer (weight). Ignored for tied embeddings",
hint=FieldHint.architecture,
)
cross_entropy_implementation: CrossEntropyImpl = Field(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These removals are likely to cause backward compatibility issues when loading existing models. Please make sure it doesn't disrupt ongoing work, and if needed add backward compatibility in _validate

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested training with checkpoints created on the main branch in both distributed and apriel2 format. Training starts with no issues.

default=CrossEntropyImpl.auto,
desc="Implementation for the cross-entropy computation.",
hint=FieldHint.performance,
)
distillation_loss_implementation: DistillationLossImpl = Field(
default=DistillationLossImpl.cross_entropy,
desc="Implementation for the distillation cross-entropy computation.",
hint=FieldHint.performance,
)
cross_entropy_splits: int | None = Field(
default=None,
desc="Split the logit and cross-entropy computation into this many fragment, to reduce memory usage.",
hint=FieldHint.feature,
valid=skip_valid_if_none(check_field(Assert.gt, 0)),
)
logit_z_loss: float = Field(
default=0.0,
desc="Regularize the logits with Z-loss.",
doc="We recommend 1e-4 for stability, as used for training PaLM.",
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)
language_model_loss_factor: float = Field(
default=None,
desc="Factor to scale the language modeling loss by when using distillation.",
hint=FieldHint.feature,
)
distillation_loss_factor: float = Field(
default=1.0,
desc="Factor to scale the distillation loss by when using distillation.",
hint=FieldHint.feature,
)
logits_scale_factor: float = Field(
default=1.0,
desc="Multiply output logits by scale factor.",
Expand All @@ -181,10 +144,10 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig):
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)
teacher_softmax_temperature: float = Field(
default=1.0,
desc="Divides distillation target logits by this factor.",
doc="Divides distillation target logits by this factor.",
logit_z_loss: float = Field(
default=0.0,
desc="Regularize the logits with Z-loss.",
doc="We recommend 1e-4 for stability, as used for training PaLM.",
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)
Expand All @@ -193,11 +156,6 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig):
desc="Name of the reference model to use for dpo.",
hint=FieldHint.feature,
)
dpo_beta: float | None = Field(
default=1.0,
desc="Beta value for DPO loss.",
hint=FieldHint.feature,
)
distillation_model: str | None = Field(
default=None,
desc="Name of the reference model to use for knowledge distillation."
Expand Down Expand Up @@ -237,11 +195,19 @@ def layer_class(self) -> "type[LanguageModelHead]":

def _validate(self) -> None:
with self._set_implicit_default():
if self.language_model_loss_factor is None:
if self.distillation_model is None:
self.language_model_loss_factor = 1.0
else:
self.language_model_loss_factor = 0.0
if not self.losses:
if "losses" not in self._explicit_fields:
self.losses = {
"lm_loss": LanguageModelLossConfig._from_dict(
{
"type": "cross_entropy",
"weight": 1.0,
}
)
}
for loss_config in self.losses.values():
if "distillation" in loss_config.type:
assert self.distillation_model is not None, "Distillation loss requires a distillation model."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the distillation model go with the loss?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm, this raises error when there is no distillation mode, this is correct, no?

super()._validate()
assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both

Expand Down
3 changes: 2 additions & 1 deletion fast_llm/layers/language_model/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
from fast_llm.layers.block.block import Block
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelKwargs
from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig
from fast_llm.layers.language_model.kwargs import LanguageModelKwargs
from fast_llm.tensor import TensorMeta
from fast_llm.utils import Assert

Expand Down
Loading