From 00a3327e475b0a940438ab11c8bb4cdf13112adc Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 11 Dec 2025 22:08:40 +0000 Subject: [PATCH 01/19] fixes masked loss distillation --- fast_llm/data/sample/range.py | 2 +- fast_llm/models/gpt/model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index 8dd351e1..22d5e899 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -38,7 +38,7 @@ def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: sample_size = 0 for document in documents: for begin, end in document.ranges: - ranges.extend((begin + sample_size, end + sample_size)) + ranges.append((begin + sample_size, end + sample_size)) sample_size += document.sample_size return cls(ranges, sample_size) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index a0c38143..41a59ca1 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -247,7 +247,7 @@ def preprocess_batch( for sample_index, loss_masking_spans in enumerate(loss_masking_spans.ranges): for begin, end in loss_masking_spans: loss_mask[sample_index, begin:end] = False - if self._config.output_layer.distillation_model is not None: + if self._config.decoder.block.distillation_model is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) From ba2c0618476955059f0ffb4182408667dee837e6 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 12 Dec 2025 16:48:57 +0000 Subject: [PATCH 02/19] test forward with loss masks --- fast_llm/data/sample/range.py | 3 +++ fast_llm/models/gpt/model.py | 2 +- tests/utils/dataset.py | 9 +++++++-- tests/utils/model_configs.py | 24 +++++++++++++++++++++++- 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index 22d5e899..a7784672 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -33,6 +33,9 @@ def __init__(self, ranges: list[tuple[int, int]], sample_size: int): @classmethod def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + """ + Used to merge ranges from multiple documents, i.e. when multiple documents are packed together. + """ document: RangeSample ranges = [] sample_size = 0 diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 41a59ca1..fd8d2af1 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -247,7 +247,7 @@ def preprocess_batch( for sample_index, loss_masking_spans in enumerate(loss_masking_spans.ranges): for begin, end in loss_masking_spans: loss_mask[sample_index, begin:end] = False - if self._config.decoder.block.distillation_model is not None: + if self._config.head.distillation_model is not None: kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index e39b74fa..be44ae61 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -226,7 +226,7 @@ def _get_test_dataset( preparator_config.run() config = ( - {"type": "file", "path": config_paths[0]} + {"type": "file", "path": config_paths[0]} # TODO: shouldn't this be {"training": {...}}?
if splits is None else { split: {"type": "file", "path": config_path} @@ -284,7 +284,12 @@ def get_test_dataset_with_loss_masking_spans( config_only: bool = False, ) -> tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig]: return _get_test_dataset( - DATASET_CACHE / "dataset_with_loss_masking_spans", seed=1234, max_loss_masking_spans=5, config_only=config_only + DATASET_CACHE / "dataset_with_loss_masking_spans", + seed=1234, + max_vocab_size=MODEL_TEST_VOCAB_SIZE, + max_loss_masking_spans=5, + splits={"training": 969, "validation": 30, "test": 1}, + config_only=config_only, ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index e943dc96..f48a4467 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -26,7 +26,11 @@ Qwen2CheckpointFormat, ) from fast_llm.models.multimodal.conversion.config import Apriel2CheckpointFormat, LlavaCheckpointFormat -from tests.utils.dataset import get_model_test_dataset, get_multimodal_test_dataset +from tests.utils.dataset import ( + get_model_test_dataset, + get_multimodal_test_dataset, + get_test_dataset_with_loss_masking_spans, +) from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.global_variables import MODEL_TEST_VOCAB_SIZE @@ -403,6 +407,24 @@ def _update_and_add_testing_config( }, ) +_update_and_add_testing_config( + "llama", + "llama_with_loss_masking", + updates={ + ("batch", "use_loss_masking_spans"): True, + }, + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.unimportant, + ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, + ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, + ModelTestingGroup.megatron: ModelTestingGroupAction.unimportant, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + }, + compare_factor=1.5, # Loss masking seem to induce slight numerical variation between dtypes + get_dataset=get_test_dataset_with_loss_masking_spans, +) + _update_and_add_testing_config( # Tests yarn-style rotary embeddings. 
"llama", From 493fe879636f7f77eb4dcd23e4135f787834aeff Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 12 Dec 2025 18:48:41 +0000 Subject: [PATCH 03/19] fix kda test --- tests/layers/test_ssm.py | 41 ++++++++++++---------------------------- 1 file changed, 12 insertions(+), 29 deletions(-) diff --git a/tests/layers/test_ssm.py b/tests/layers/test_ssm.py index e6422c59..515b89a8 100644 --- a/tests/layers/test_ssm.py +++ b/tests/layers/test_ssm.py @@ -10,9 +10,7 @@ from fast_llm.layers.ssm import kda as kda_module from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, Mamba2Config from fast_llm.utils import Assert -from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet, Apriel2Mamba -from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig -from fast_llm_external_models.apriel_hybrid_ssm.modeling_apriel_hybrid_ssm import KimiDeltaAttention +from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet, Apriel2Mamba, KimiDeltaAttention from tests.utils.utils import get_stage, requires_cuda HIDDEN_SIZE = 16 @@ -102,39 +100,24 @@ def test_gdn(): @pytest.mark.slow @requires_cuda -@pytest.mark.skipif(KimiDeltaAttention is None or AprielHybridSSMConfig is None, reason="Apriel KDA deps missing") @pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available") def test_kda(): NUM_HEADS = 4 HEAD_DIM = 4 KERNEL_SIZE = 4 - hf_config = AprielHybridSSMConfig( - hidden_size=HIDDEN_SIZE, - num_attention_heads=NUM_HEADS, - num_hidden_layers=1, - rms_norm_eps=1e-6, - ) - hf_config.short_conv_kernel_size = KERNEL_SIZE - hf_config.head_dim = HEAD_DIM - hf_config.num_heads = NUM_HEADS - hf_layer = KimiDeltaAttention(hf_config, layer_idx=0) - - fast_llm_config = KimiDeltaAttentionConfig( - heads=NUM_HEADS, - head_dim=HEAD_DIM, - convolution_layer={"kernel_size": KERNEL_SIZE, "activation": "silu"}, - normalization={"epsilon": 1e-6, "activation": "sigmoid"}, - ) - - param_map = { - "q_conv.weight": "q_conv1d.weight", - "k_conv.weight": "k_conv1d.weight", - "v_conv.weight": "v_conv1d.weight", - "beta_proj.weight": "b_proj.weight", - "norm.weight": "o_norm.weight", + kda_config = { + "heads": NUM_HEADS, + "head_dim": HEAD_DIM, + "convolution_layer": {"kernel_size": KERNEL_SIZE, "activation": "silu"}, + "normalization": {"epsilon": 1e-5, "activation": "sigmoid"}, } - _compare_mixers(fast_llm_config, hf_layer, param_map) + + hf_layer = KimiDeltaAttention(HIDDEN_SIZE, kda_config, layer_idx=0) + + fast_llm_config = KimiDeltaAttentionConfig.from_dict(kda_config, {}) + + _compare_mixers(fast_llm_config, hf_layer, {}) @pytest.mark.slow From c68a7429b5b39874aac8bee6044a5b59e48428f0 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 12 Dec 2025 19:05:58 +0000 Subject: [PATCH 04/19] varlen test fix --- tests/layers/test_varlen.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/layers/test_varlen.py b/tests/layers/test_varlen.py index 32cd00cd..54f03958 100644 --- a/tests/layers/test_varlen.py +++ b/tests/layers/test_varlen.py @@ -22,12 +22,14 @@ "config", [ AttentionConfig(heads=4, head_groups=2, head_size=16, cross_document_attention=False), - Mamba2Config( - d_inner=128, - d_xb=64, - state_size=16, - dt_rank=8, - cross_document_attention=False, + pytest.param( + Mamba2Config( + d_inner=128, + d_xb=64, + state_size=16, + dt_rank=8, + cross_document_attention=False, + ), marks=pytest.mark.skip("Mamba varlen kernel not 
available"), ), pytest.param( From daba344e1a1f012e9cc770c6706f29ad0459da09 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 12 Dec 2025 23:28:54 +0000 Subject: [PATCH 05/19] manual kl grad computation --- fast_llm/functional/cross_entropy.py | 20 +++++++++++++++----- tests/functional/test_cross_entropy.py | 11 ++++++----- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 42b0c214..8bc56349 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -58,11 +58,13 @@ def _fused_softmax_base( logits *= logits_scale_factor logits_max = torch.max(logits, dim=dim, keepdim=True)[0] if group is not None: + # Use autograd-aware all_reduce with correct gradient behavior all_reduce(logits_max, op=ReduceOp.MAX, group=group) logits_norm = (logits - logits_max).float() exp_logits = logits_norm.exp() sum_exp_logits = exp_logits.sum(dim=dim, keepdim=True) if group is not None: + # Use autograd-aware all_reduce with correct gradient behavior all_reduce(sum_exp_logits, op=ReduceOp.SUM, group=group) return logits_norm, exp_logits, sum_exp_logits @@ -227,7 +229,7 @@ def distributed_log_softmax( return logits_norm - sum_exp_logits.log() # log_softmax -def _torch_reverse_kl_forward_backward( +def _reverse_kl_forward_backward( logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None, @@ -261,7 +263,6 @@ def _torch_reverse_kl_forward_backward( # Compute log probabilities teacher_log_probs = distributed_log_softmax(target.float(), group=group) - # batch_size = logits.shape[0] with torch.enable_grad(): logits_ = logits.float().detach().requires_grad_(grad_output is not None) student_log_probs = distributed_log_softmax(logits_, group=group) @@ -287,8 +288,17 @@ def _torch_reverse_kl_forward_backward( loss /= valid_tokens if grad_output is not None: - loss.backward(torch.full_like(loss, grad_output)) - grad = logits_.grad.to(logits.dtype) + log_ratio = student_log_probs - teacher_log_probs + expected = torch.sum(torch.exp(student_log_probs) * log_ratio, dim=-1, keepdim=True) + grad_base = torch.exp(student_log_probs) * (log_ratio - expected) + + # Apply mask to gradients, not to log_probs! + if loss_mask is not None: + valid = loss_mask.to(logits.dtype).unsqueeze(-1) + grad_base = grad_base * valid + + grad = grad_base.mul(grad_output / valid_tokens) + grad = grad.to(logits.dtype) else: grad = None @@ -339,7 +349,7 @@ def reverse_kl_forward_backward( Assert.eq(loss_mask.shape, logits.shape[:-1]) # TODO: implement fused? 
- distillation_loss, distillation_grad = _torch_reverse_kl_forward_backward( + distillation_loss, distillation_grad = _reverse_kl_forward_backward( logits=logits, target=target, loss_mask=loss_mask, diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index a23b49f8..b4ea6964 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -40,10 +40,11 @@ def _compare_cross_entropy_outputs( grad: torch.Tensor | None, ref_grad: torch.Tensor | None, threshold=1e-5, + min_threshold_grads=1e-8, ): Assert.rms_close_relative(loss, ref_loss, threshold, 1e-6) if has_grad: - Assert.rms_close_relative(grad, ref_grad, threshold, 1e-8) + Assert.rms_close_relative(grad, ref_grad, threshold, min_threshold_grads) else: assert grad is None assert ref_grad is None @@ -114,8 +115,8 @@ def _reverse_kl_forward_backward_torch(target: torch.Tensor, logits: torch.Tenso @pytest.mark.parametrize("loss_masking", [False, True]) @pytest.mark.parametrize("target_format", (TargetFormat.logits,)) def test_reverse_kl(loss_masking, target_format): - logits, target, loss_mask = _get_cross_entropy_inputs(10000, loss_masking, target_format) - out_ref, grad_ref = _reverse_kl_forward_backward_torch(logits, target, loss_mask) + logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) + out_ref, grad_ref = _reverse_kl_forward_backward_torch(target, logits, loss_mask) out, grad = reverse_kl_forward_backward( logits=logits, target=target, @@ -184,12 +185,12 @@ def _compare_parallel_cross_entropy( grad_output=1, target_format=target_format, ) - _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4) + _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4, 1e-6) def compare_parallel_cross_entropy(rank: int, group: torch.distributed.ProcessGroup): success = True - for function in (cross_entropy_forward_backward, reverse_kl_forward_backward): + for function in (reverse_kl_forward_backward, cross_entropy_forward_backward): for target_format in (TargetFormat.logits,): for loss_masking in [True, False]: try: From 1d0df170ca5bf1d6b654c2df948432f145fd59e6 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 12 Dec 2025 23:30:21 +0000 Subject: [PATCH 06/19] comment --- fast_llm/functional/cross_entropy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 8bc56349..484dfb39 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -288,6 +288,7 @@ def _reverse_kl_forward_backward( loss /= valid_tokens if grad_output is not None: + # need to calculate gradient manually, backprop through all reduce can be problematic, see https://github.com/pytorch/pytorch/issues/58005 log_ratio = student_log_probs - teacher_log_probs expected = torch.sum(torch.exp(student_log_probs) * log_ratio, dim=-1, keepdim=True) grad_base = torch.exp(student_log_probs) * (log_ratio - expected) From 9ae4e73c20407bf0451a8cc325f7823a74be951f Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 12 Dec 2025 23:32:44 +0000 Subject: [PATCH 07/19] clean --- fast_llm/functional/cross_entropy.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 484dfb39..223e9037 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -58,13 +58,11 @@ def 
_fused_softmax_base( logits *= logits_scale_factor logits_max = torch.max(logits, dim=dim, keepdim=True)[0] if group is not None: - # Use autograd-aware all_reduce with correct gradient behavior all_reduce(logits_max, op=ReduceOp.MAX, group=group) logits_norm = (logits - logits_max).float() exp_logits = logits_norm.exp() sum_exp_logits = exp_logits.sum(dim=dim, keepdim=True) if group is not None: - # Use autograd-aware all_reduce with correct gradient behavior all_reduce(sum_exp_logits, op=ReduceOp.SUM, group=group) return logits_norm, exp_logits, sum_exp_logits From 3a3d06e6c1192a5cac3921abe972e70427a5c382 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 12 Dec 2025 23:55:28 +0000 Subject: [PATCH 08/19] tests --- tests/utils/dataset.py | 14 ++++++++++++-- tests/utils/model_configs.py | 4 ++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index be44ae61..27c8bdfb 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -280,15 +280,25 @@ def get_split_sharded_test_dataset() -> ( ) -def get_test_dataset_with_loss_masking_spans( +def get_dataset_with_loss_masking_spans( config_only: bool = False, ) -> tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig]: return _get_test_dataset( DATASET_CACHE / "dataset_with_loss_masking_spans", seed=1234, - max_vocab_size=MODEL_TEST_VOCAB_SIZE, max_loss_masking_spans=5, + config_only=config_only, splits={"training": 969, "validation": 30, "test": 1}, + ) + + +def get_test_dataset_with_loss_masking_spans( + config_only: bool = False, +) -> tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig]: + return _get_test_dataset( + DATASET_CACHE / "dataset_with_loss_masking_spans", + seed=1234, + max_loss_masking_spans=5, config_only=config_only, ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index e42be710..2c6a88ce 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -27,9 +27,9 @@ ) from fast_llm.models.multimodal.conversion.config import Apriel2CheckpointFormat, LlavaCheckpointFormat from tests.utils.dataset import ( + get_dataset_with_loss_masking_spans, get_model_test_dataset, get_multimodal_test_dataset, - get_test_dataset_with_loss_masking_spans, ) from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.global_variables import MODEL_TEST_VOCAB_SIZE @@ -423,7 +423,7 @@ def _update_and_add_testing_config( ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, }, compare_factor=1.5, # Loss masking seem to induce slight numerical variation between dtypes - get_dataset=get_test_dataset_with_loss_masking_spans, + get_dataset=get_dataset_with_loss_masking_spans, ) _update_and_add_testing_config( From bc2c525e00a306e116f60b7080786d6706afa020 Mon Sep 17 00:00:00 2001 From: oleksost Date: Sat, 13 Dec 2025 00:06:08 +0000 Subject: [PATCH 09/19] test device --- tests/functional/test_cross_entropy.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index b4ea6964..c5214df5 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -14,19 +14,19 @@ def _get_cross_entropy_inputs( - num_columns: int, loss_masking: bool, target_format: TargetFormat + num_columns: int, loss_masking: bool, target_format: TargetFormat, device="cuda" ) -> tuple[torch.Tensor, torch.Tensor, 
torch.Tensor]: # We want something moderately close to the target for the test to be meaningful - logits_var = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda") / 3 - loss_mask = torch.randint(0, 2, (256,), dtype=torch.bool, device="cuda") if loss_masking else None + logits_var = torch.randn(256, num_columns, dtype=torch.bfloat16, device=device) / 3 + loss_mask = torch.randint(0, 2, (256,), dtype=torch.bool, device=device) if loss_masking else None if target_format == TargetFormat.labels: - target = torch.randint(0, num_columns, (256,), dtype=torch.int64, device="cuda") + target = torch.randint(0, num_columns, (256,), dtype=torch.int64, device=device) logits = torch.nn.functional.one_hot(target, num_columns) + logits_var if loss_masking: logits = torch.where(loss_mask.unsqueeze(-1), logits, -100) loss_mask = None else: - target = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda") + target = torch.randn(256, num_columns, dtype=torch.bfloat16, device=device) logits = target + logits_var if target_format == TargetFormat.probabilities: target = torch.softmax(target, -1) @@ -115,7 +115,9 @@ def _reverse_kl_forward_backward_torch(target: torch.Tensor, logits: torch.Tenso @pytest.mark.parametrize("loss_masking", [False, True]) @pytest.mark.parametrize("target_format", (TargetFormat.logits,)) def test_reverse_kl(loss_masking, target_format): - logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) + logits, target, loss_mask = _get_cross_entropy_inputs( + 1000, loss_masking, target_format, device="cuda" if torch.cuda.is_available() else "cpu" + ) out_ref, grad_ref = _reverse_kl_forward_backward_torch(target, logits, loss_mask) out, grad = reverse_kl_forward_backward( logits=logits, @@ -124,7 +126,6 @@ def test_reverse_kl(loss_masking, target_format): grad_output=1.0, target_format=TargetFormat.logits, ) - # TODO: Error looks _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref, 1e-3) @@ -167,7 +168,9 @@ def _compare_parallel_cross_entropy( # Ensure all workers have the same inputs. 
torch.manual_seed(0) world_size = torch.distributed.get_world_size(group) - logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) + logits, target, loss_mask = _get_cross_entropy_inputs( + 1000, loss_masking, target_format, device="cuda" if torch.cuda.is_available() else "cpu" + ) out, grad = function( logits=logits.chunk(world_size, 1)[rank], From 44c5f63a7969cf501d5d6150d0d1efbc9ea8f7c0 Mon Sep 17 00:00:00 2001 From: oleksost Date: Sat, 13 Dec 2025 16:47:28 +0000 Subject: [PATCH 10/19] grad fix --- fast_llm/functional/cross_entropy.py | 7 +++++-- tests/functional/test_cross_entropy.py | 7 +++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 223e9037..44cf2114 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -262,8 +262,8 @@ def _reverse_kl_forward_backward( # Compute log probabilities teacher_log_probs = distributed_log_softmax(target.float(), group=group) with torch.enable_grad(): - logits_ = logits.float().detach().requires_grad_(grad_output is not None) - student_log_probs = distributed_log_softmax(logits_, group=group) + # logits_ = logits.float()#.detach().requires_grad_(grad_output is not None) + student_log_probs = distributed_log_softmax(logits, group=group) # Reverse KL: input=teacher_log_probs, target=student_probs loss_terms = torch.nn.functional.kl_div( @@ -289,6 +289,9 @@ def _reverse_kl_forward_backward( # need to calculate gradient manually, backprop through all reduce can be problematic, see https://github.com/pytorch/pytorch/issues/58005 log_ratio = student_log_probs - teacher_log_probs expected = torch.sum(torch.exp(student_log_probs) * log_ratio, dim=-1, keepdim=True) + # expected E_q(log s - log t) -- this is actually dependent on the full vocab! + if group is not None: + all_reduce(expected, op=ReduceOp.SUM, group=group) grad_base = torch.exp(student_log_probs) * (log_ratio - expected) # Apply mask to gradients, not to log_probs! 
diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index c5214df5..ebd3402c 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -40,11 +40,10 @@ def _compare_cross_entropy_outputs( grad: torch.Tensor | None, ref_grad: torch.Tensor | None, threshold=1e-5, - min_threshold_grads=1e-8, ): Assert.rms_close_relative(loss, ref_loss, threshold, 1e-6) if has_grad: - Assert.rms_close_relative(grad, ref_grad, threshold, min_threshold_grads) + Assert.rms_close_relative(grad, ref_grad, threshold, 1e-8) else: assert grad is None assert ref_grad is None @@ -188,14 +187,14 @@ def _compare_parallel_cross_entropy( grad_output=1, target_format=target_format, ) - _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4, 1e-6) + _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4) def compare_parallel_cross_entropy(rank: int, group: torch.distributed.ProcessGroup): success = True for function in (reverse_kl_forward_backward, cross_entropy_forward_backward): for target_format in (TargetFormat.logits,): - for loss_masking in [True, False]: + for loss_masking in [False, True]: try: _compare_parallel_cross_entropy(rank, group, target_format, function, loss_masking) except Exception: From 0111e9f1a20ac9f1024d11f2e16fb3e02a09400b Mon Sep 17 00:00:00 2001 From: oleksost Date: Sat, 13 Dec 2025 16:59:03 +0000 Subject: [PATCH 11/19] fixes --- fast_llm/models/gpt/model.py | 5 ++++- tests/utils/dataset.py | 17 +++-------------- tests/utils/model_configs.py | 9 +++------ 3 files changed, 10 insertions(+), 21 deletions(-) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index fd8d2af1..32eaf8c3 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -247,7 +247,10 @@ def preprocess_batch( for sample_index, loss_masking_spans in enumerate(loss_masking_spans.ranges): for begin, end in loss_masking_spans: loss_mask[sample_index, begin:end] = False - if self._config.head.distillation_model is not None: + if ( + self._config.head.distillation_model is not None + and self._config.decoder.block.distillation_model is not None + ): kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 27c8bdfb..b2b5db0d 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -226,7 +226,7 @@ def _get_test_dataset( preparator_config.run() config = ( - {"type": "file", "path": config_paths[0]} # TODO: shouldn't this be {"training": {...}}? 
+ {"type": "file", "path": config_paths[0]} if splits is None else { split: {"type": "file", "path": config_path} @@ -280,18 +280,6 @@ def get_split_sharded_test_dataset() -> ( ) -def get_dataset_with_loss_masking_spans( - config_only: bool = False, -) -> tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig]: - return _get_test_dataset( - DATASET_CACHE / "dataset_with_loss_masking_spans", - seed=1234, - max_loss_masking_spans=5, - config_only=config_only, - splits={"training": 969, "validation": 30, "test": 1}, - ) - - def get_test_dataset_with_loss_masking_spans( config_only: bool = False, ) -> tuple[pathlib.Path, dict[str, typing.Any], pathlib.Path, LanguageModelPreprocessingConfig]: @@ -330,10 +318,11 @@ def get_test_dataset_with_image_patches( ) -def get_model_test_dataset(config_only: bool = False): +def get_model_test_dataset(config_only: bool = False, use_loss_masking: bool = False): return _get_test_dataset( DATASET_CACHE / "model_dataset", seed=1234, + max_loss_masking_spans=5 if use_loss_masking else 0, max_vocab_size=MODEL_TEST_VOCAB_SIZE, splits={"training": 969, "validation": 30, "test": 1}, config_only=config_only, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 2c6a88ce..99356d41 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -6,6 +6,7 @@ import pathlib import re import typing +from functools import partial import pytest import transformers @@ -26,11 +27,7 @@ Qwen2CheckpointFormat, ) from fast_llm.models.multimodal.conversion.config import Apriel2CheckpointFormat, LlavaCheckpointFormat -from tests.utils.dataset import ( - get_dataset_with_loss_masking_spans, - get_model_test_dataset, - get_multimodal_test_dataset, -) +from tests.utils.dataset import get_model_test_dataset, get_multimodal_test_dataset from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.global_variables import MODEL_TEST_VOCAB_SIZE @@ -423,7 +420,7 @@ def _update_and_add_testing_config( ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, }, compare_factor=1.5, # Loss masking seem to induce slight numerical variation between dtypes - get_dataset=get_dataset_with_loss_masking_spans, + get_dataset=partial(get_model_test_dataset, use_loss_masking=True), ) _update_and_add_testing_config( From f6238c09058d4854bd9c9d62a1d80b33fa4040b3 Mon Sep 17 00:00:00 2001 From: oleksost Date: Sat, 13 Dec 2025 17:04:14 +0000 Subject: [PATCH 12/19] clean --- fast_llm/functional/cross_entropy.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 44cf2114..7e60d211 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -294,7 +294,6 @@ def _reverse_kl_forward_backward( all_reduce(expected, op=ReduceOp.SUM, group=group) grad_base = torch.exp(student_log_probs) * (log_ratio - expected) - # Apply mask to gradients, not to log_probs! 
if loss_mask is not None: valid = loss_mask.to(logits.dtype).unsqueeze(-1) grad_base = grad_base * valid From f28b241b77f59ed3ea9d5c3e2a627b91729db7da Mon Sep 17 00:00:00 2001 From: oleksost Date: Sat, 13 Dec 2025 17:05:36 +0000 Subject: [PATCH 13/19] clean --- tests/layers/test_varlen.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/layers/test_varlen.py b/tests/layers/test_varlen.py index 54f03958..a59e0542 100644 --- a/tests/layers/test_varlen.py +++ b/tests/layers/test_varlen.py @@ -100,7 +100,3 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig): for name, parameter, grad_packed in zip(names, parameters, grads_packed, strict=True): Assert.rms_close_relative(grad_packed, parameter.grad_buffer, 1e-3, 1e-4, msg=name) - - -if __name__ == "__main__": - pytest.main([__file__]) From e41c040c826c2bac610ccebb3ac32f8bd3e5f366 Mon Sep 17 00:00:00 2001 From: oleksost Date: Sat, 13 Dec 2025 17:13:46 +0000 Subject: [PATCH 14/19] nvm --- tests/layers/test_varlen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/layers/test_varlen.py b/tests/layers/test_varlen.py index a59e0542..c8d962f4 100644 --- a/tests/layers/test_varlen.py +++ b/tests/layers/test_varlen.py @@ -10,7 +10,7 @@ from fast_llm.layers.decoder.config import MixerConfig from fast_llm.layers.ssm import gdn as gdn_module from fast_llm.layers.ssm import kda as kda_module -from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, Mamba2Config +from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig from fast_llm.utils import Assert from tests.utils.utils import get_stage, requires_cuda @@ -23,7 +23,7 @@ [ AttentionConfig(heads=4, head_groups=2, head_size=16, cross_document_attention=False), pytest.param( - Mamba2Config( + MambaConfig( d_inner=128, d_xb=64, state_size=16, From b6e8775fd80754822ce57a4977a3edd4a0135244 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 15 Dec 2025 13:51:25 +0000 Subject: [PATCH 15/19] clean warnings --- fast_llm/functional/cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 7e60d211..cffb88d1 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -276,7 +276,7 @@ def _reverse_kl_forward_backward( # loss mask is the same on all ranks for TP over vocab. valid = loss_mask.to(loss_terms.dtype) loss_terms = loss_terms * valid - valid_tokens = torch.tensor(valid.sum(), device=loss_terms.device, dtype=loss_terms.dtype) + valid_tokens = valid.sum() else: valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype)) loss = loss_terms.sum() # sums over batch and seq. len. 
From dcd55a52526c6ce42181c8f47dc2f0a0e96b5cf4 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 16 Dec 2025 14:35:26 +0000 Subject: [PATCH 16/19] clean up --- fast_llm/functional/cross_entropy.py | 76 +++++++++++++------------- tests/functional/test_cross_entropy.py | 8 +-- tests/utils/dataset.py | 4 +- tests/utils/model_configs.py | 2 - 4 files changed, 41 insertions(+), 49 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index cffb88d1..8c9ea939 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -261,47 +261,45 @@ def _reverse_kl_forward_backward( # Compute log probabilities teacher_log_probs = distributed_log_softmax(target.float(), group=group) - with torch.enable_grad(): - # logits_ = logits.float()#.detach().requires_grad_(grad_output is not None) - student_log_probs = distributed_log_softmax(logits, group=group) - - # Reverse KL: input=teacher_log_probs, target=student_probs - loss_terms = torch.nn.functional.kl_div( - teacher_log_probs, # input = log(p) - student_log_probs, # target = log(q) - reduction="none", - log_target=True, - ).sum(dim=-1) - if loss_mask is not None: - # loss mask is the same on all ranks for TP over vocab. - valid = loss_mask.to(loss_terms.dtype) - loss_terms = loss_terms * valid - valid_tokens = valid.sum() - else: - valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype)) - loss = loss_terms.sum() # sums over batch and seq. len. + student_log_probs = distributed_log_softmax(logits, group=group) + + # Reverse KL: input=teacher_log_probs, target=student_probs + loss_terms = torch.nn.functional.kl_div( + teacher_log_probs, # input = log(p) + student_log_probs, # target = log(q) + reduction="none", + log_target=True, + ).sum(dim=-1) + if loss_mask is not None: + # loss mask is the same on all ranks for TP over vocab. + valid = loss_mask.to(loss_terms.dtype) + loss_terms = loss_terms * valid + valid_tokens = valid.sum() + else: + valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype)) + loss = loss_terms.sum() # sums over batch and seq. len. + if group is not None: + all_reduce(loss, op=ReduceOp.SUM, group=group) + loss /= valid_tokens + + if grad_output is not None: + # need to calculate gradient manually, backprop through all reduce can be problematic, see https://github.com/pytorch/pytorch/issues/58005 + log_ratio = student_log_probs - teacher_log_probs + expected = torch.sum(torch.exp(student_log_probs) * log_ratio, dim=-1, keepdim=True) + # expected E_q(log s - log t) -- this is actually dependent on the full vocab! if group is not None: - all_reduce(loss, op=ReduceOp.SUM, group=group) - loss /= valid_tokens - - if grad_output is not None: - # need to calculate gradient manually, backprop through all reduce can be problematic, see https://github.com/pytorch/pytorch/issues/58005 - log_ratio = student_log_probs - teacher_log_probs - expected = torch.sum(torch.exp(student_log_probs) * log_ratio, dim=-1, keepdim=True) - # expected E_q(log s - log t) -- this is actually dependent on the full vocab! 
- if group is not None: - all_reduce(expected, op=ReduceOp.SUM, group=group) - grad_base = torch.exp(student_log_probs) * (log_ratio - expected) - - if loss_mask is not None: - valid = loss_mask.to(logits.dtype).unsqueeze(-1) - grad_base = grad_base * valid - - grad = grad_base.mul(grad_output / valid_tokens) - grad = grad.to(logits.dtype) - else: - grad = None + all_reduce(expected, op=ReduceOp.SUM, group=group) + grad_base = torch.exp(student_log_probs) * (log_ratio - expected) + + if loss_mask is not None: + valid = loss_mask.to(logits.dtype).unsqueeze(-1) + grad_base = grad_base * valid + + grad = grad_base.mul(grad_output / valid_tokens) + grad = grad.to(logits.dtype) + else: + grad = None return loss.detach_(), grad diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index 8f2e3def..afac1296 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -115,9 +115,7 @@ def _reverse_kl_forward_backward_torch(target: torch.Tensor, logits: torch.Tenso @pytest.mark.parametrize("loss_masking", [False, True]) @pytest.mark.parametrize("target_format", (TargetFormat.logits,)) def test_reverse_kl(loss_masking, target_format): - logits, target, loss_mask = _get_cross_entropy_inputs( - 1000, loss_masking, target_format, device="cuda" if torch.cuda.is_available() else "cpu" - ) + logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) out_ref, grad_ref = _reverse_kl_forward_backward_torch(target, logits, loss_mask) out, grad = reverse_kl_forward_backward( logits=logits, @@ -168,9 +166,7 @@ def _compare_parallel_cross_entropy( # Ensure all workers have the same inputs. torch.manual_seed(0) world_size = torch.distributed.get_world_size(group) - logits, target, loss_mask = _get_cross_entropy_inputs( - 1000, loss_masking, target_format, device="cuda" if torch.cuda.is_available() else "cpu" - ) + logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) out, grad = function( logits=logits.chunk(world_size, 1)[rank], diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index b2b5db0d..854ecec3 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -318,11 +318,11 @@ def get_test_dataset_with_image_patches( ) -def get_model_test_dataset(config_only: bool = False, use_loss_masking: bool = False): +def get_model_test_dataset(config_only: bool = False): return _get_test_dataset( DATASET_CACHE / "model_dataset", seed=1234, - max_loss_masking_spans=5 if use_loss_masking else 0, + max_loss_masking_spans=5, max_vocab_size=MODEL_TEST_VOCAB_SIZE, splits={"training": 969, "validation": 30, "test": 1}, config_only=config_only, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index ec0cbe07..b671059b 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -6,7 +6,6 @@ import pathlib import re import typing -from functools import partial import pytest import transformers @@ -420,7 +419,6 @@ def _update_and_add_testing_config( ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, }, compare_factor=1.5, # Loss masking seem to induce slight numerical variation between dtypes - get_dataset=partial(get_model_test_dataset, use_loss_masking=True), ) _update_and_add_testing_config( From 6fef1fb2ac7e7da44caff83dbf43bf3a27109b48 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 16 Dec 2025 17:22:25 +0000 Subject: [PATCH 17/19] loss mask transposition was missing --- fast_llm/models/gpt/model.py | 6 
+++++- tests/functional/test_cross_entropy.py | 4 ++-- tests/utils/model_configs.py | 18 +----------------- 3 files changed, 8 insertions(+), 20 deletions(-) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 32eaf8c3..64e7f1cb 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -249,7 +249,7 @@ def preprocess_batch( loss_mask[sample_index, begin:end] = False if ( self._config.head.distillation_model is not None - and self._config.decoder.block.distillation_model is not None + or self._config.decoder.block.distillation_model is not None ): kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) @@ -257,6 +257,10 @@ def preprocess_batch( kwargs[LanguageModelKwargs.labels] = ( labels.transpose(0, 1) if kwargs[AttentionKwargs.sequence_first] else labels ).contiguous() + if LanguageModelKwargs.loss_mask in kwargs and kwargs[AttentionKwargs.sequence_first]: + kwargs[LanguageModelKwargs.loss_mask] = ( + kwargs[LanguageModelKwargs.loss_mask].transpose(0, 1).contiguous() + ) if batch.chosen_spans is not None: kwargs[LanguageModelKwargs.chosen_spans] = batch.chosen_spans.crop(labels_begin, labels_end).ranges diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index afac1296..72644d06 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -95,7 +95,7 @@ def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_maski ) -def _reverse_kl_forward_backward_torch(target: torch.Tensor, logits: torch.Tensor, loss_mask: torch.Tensor | None): +def _reverse_kl_forward_backward_torch(logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None): # Manual reference: sum over vocab then average over valid tokens. logits = logits.detach().requires_grad_() per_sample = torch.nn.functional.kl_div( @@ -116,7 +116,7 @@ def _reverse_kl_forward_backward_torch(target: torch.Tensor, logits: torch.Tenso @pytest.mark.parametrize("target_format", (TargetFormat.logits,)) def test_reverse_kl(loss_masking, target_format): logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) - out_ref, grad_ref = _reverse_kl_forward_backward_torch(target, logits, loss_mask) + out_ref, grad_ref = _reverse_kl_forward_backward_torch(logits, target, loss_mask) out, grad = reverse_kl_forward_backward( logits=logits, target=target, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index b671059b..6156cb70 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -404,23 +404,6 @@ def _update_and_add_testing_config( }, ) -_update_and_add_testing_config( - "llama", - "llama_with_loss_masking", - updates={ - ("batch", "use_loss_masking_spans"): True, - }, - groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.normal, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.unimportant, - ModelTestingGroup.convert: ModelTestingGroupAction.unimportant, - ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, - ModelTestingGroup.megatron: ModelTestingGroupAction.unimportant, - ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, - }, - compare_factor=1.5, # Loss masking seem to induce slight numerical variation between dtypes -) - _update_and_add_testing_config( # Tests yarn-style rotary embeddings. 
"llama", @@ -569,6 +552,7 @@ def _update_and_add_testing_config( "mistral_distill_logits", updates={ ("model", "base_model", "head", "distillation_model"): "teacher", + ("batch", "use_loss_masking_spans"): True, ("reference_models"): { "teacher": { "model": {"base_model": copy.deepcopy(_mistral_base_model)}, From 8da6f108396b36618c01f9968f7ad35f4ecef6a2 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 16 Dec 2025 20:35:17 +0000 Subject: [PATCH 18/19] loss masking fixes: cross entropy averaging & training with minibatches --- fast_llm/data/sample/language_model.py | 25 ++++++++++++- fast_llm/engine/base_model/base_model.py | 1 + fast_llm/engine/schedule/runner.py | 41 +++++++++++++++++++-- fast_llm/functional/cross_entropy.py | 16 ++++---- fast_llm/functional/triton/cross_entropy.py | 15 ++++++-- fast_llm/layers/language_model/config.py | 1 + fast_llm/models/gpt/model.py | 4 ++ fast_llm/models/multimodal/model.py | 8 +++- tests/utils/model_configs.py | 6 ++- 9 files changed, 98 insertions(+), 19 deletions(-) diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index 3183a9ec..25eb249b 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -98,21 +98,41 @@ def __init__( chosen_spans: RangeBatch | None = None, rejected_spans: RangeBatch | None = None, image_patches: PatchBatch | None = None, + valid_tokens: int | None = None, ): self.tokens = tokens self.loss_masking_spans = loss_masking_spans self.chosen_spans = chosen_spans self.rejected_spans = rejected_spans self.image_patches = image_patches + self.valid_tokens = valid_tokens @classmethod def from_samples(cls, samples: typing.Iterable[LanguageModelSample]) -> typing.Self: + samples = list(samples) + token_batch = TokenBatch.from_samples([sample.tokens for sample in samples]) + loss_masking_spans = _merge_optional( + RangeBatch.from_samples, [sample.loss_masking_spans for sample in samples] + ) + + # Calculate valid tokens for this batch (used for gradient accumulation weighting) + valid_tokens = None + if loss_masking_spans is not None: + batch_size, sequence_length = token_batch.tokens.shape + # Start with all tokens + valid_tokens = batch_size * sequence_length + # Subtract masked tokens + for sample_ranges in loss_masking_spans.ranges: + for begin, end in sample_ranges: + valid_tokens -= end - begin + return cls( - TokenBatch.from_samples([sample.tokens for sample in samples]), - _merge_optional(RangeBatch.from_samples, [sample.loss_masking_spans for sample in samples]), + token_batch, + loss_masking_spans, _merge_optional(RangeBatch.from_samples, [sample.chosen_spans for sample in samples]), _merge_optional(RangeBatch.from_samples, [sample.rejected_spans for sample in samples]), _merge_optional(PatchBatch.from_samples, [sample.image_patches for sample in samples]), + valid_tokens, ) def crop(self, begin: int, end: int) -> typing.Self: @@ -122,6 +142,7 @@ def crop(self, begin: int, end: int) -> typing.Self: _crop_optional(self.chosen_spans, begin, end), _crop_optional(self.rejected_spans, begin, end), _crop_optional(self.image_patches, begin, end), + valid_tokens=None, # Cropped batches don't have valid token counts ) def to_device_(self, device: "torch.device | str"): diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index ffffbed5..e41b686d 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -179,6 +179,7 @@ def preprocess_batch( phase: PhaseType, 
iteration: int, metrics: dict | None = None, + total_valid_tokens: int | None = None, ) -> list[tuple[torch.Tensor, dict]]: # TODO Move batch splitting elsewhere, align interface with LayerBase pass diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 133b3206..5078bf4c 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -10,6 +10,7 @@ from fast_llm.config import Configurable from fast_llm.core.distributed import all_reduce, recv, safe_barrier, send +from fast_llm.data.sample.language_model import LanguageModelBatch from fast_llm.engine.config_utils.run import get_run, log_pipeline_parallel_main_rank from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed @@ -19,6 +20,7 @@ from fast_llm.engine.schedule.config import EventType, ScheduleConfig, StepType, StreamType from fast_llm.engine.schedule.schedule import Schedule, Step from fast_llm.logging import log_memory_usage +from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -319,10 +321,31 @@ def _train_step(self, context: BatchContext, step: Step) -> None: def _preprocess_data( self, context: BatchContext, data_iterator: typing.Iterator, preprocessed: bool ) -> typing.Generator[None, None, None]: - batch_config = context.schedule.batch_config - grad_output = (1 if self._optimizer is None else self._optimizer.grad_scale) / batch_config.num_inputs + from fast_llm.layers.language_model.config import LanguageModelKwargs + + batch_config: GPTBatchConfig = context.schedule.batch_config + default_grad_output = (1 if self._optimizer is None else self._optimizer.grad_scale) / batch_config.num_inputs + + # We need additional pass to compute total valid tokens, which is needed to correctly set grad weights when using loss masks + grad accumulation + # TODO: add conditions? This must not be used always + all_micro_batches = [] + total_valid_tokens = None for micro_batch in range(batch_config.sequential_micro_batches): - micro_batch_data = next(data_iterator) + micro_batch_data: LanguageModelBatch = next(data_iterator) + all_micro_batches.append(micro_batch_data) + + # Sum valid tokens across all microbatches (if loss masking is used) + if ( + not preprocessed + and hasattr(micro_batch_data, "valid_tokens") + and micro_batch_data.valid_tokens is not None + ): + if total_valid_tokens is None: + total_valid_tokens = 0 + total_valid_tokens += micro_batch_data.valid_tokens + + # Second pass: Preprocess and yield each microbatch with correct gradient weighting + for micro_batch, micro_batch_data in enumerate(all_micro_batches): if not preprocessed: micro_batch_data = self._multi_stage.base_model.preprocess_batch( micro_batch_data, @@ -330,8 +353,20 @@ def _preprocess_data( phase=context.phase, iteration=context.iteration, metrics=context.metrics, + total_valid_tokens=total_valid_tokens, ) for micro_batch_split, (input_, kwargs) in enumerate(micro_batch_data): + # Compute grad_output based on valid tokens when loss masking is used + if LanguageModelKwargs.loss_mask in kwargs and total_valid_tokens is not None: + loss_mask = kwargs[LanguageModelKwargs.loss_mask] + valid_tokens = loss_mask.sum().item() + # Weight this micro-batch by its proportion of valid tokens. 
This is required to correctly scale the gradients when different microbatches have different number of valid tokens + grad_output = (1 if self._optimizer is None else self._optimizer.grad_scale) * ( + valid_tokens / total_valid_tokens + ) + else: + grad_output = default_grad_output + kwargs.update( grad_output=grad_output, micro_batch=micro_batch, diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 8c9ea939..1123ed5d 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -35,12 +35,10 @@ def _torch_cross_entropy_forward_backward( logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target ) else: - loss = ( - torch.nn.functional.cross_entropy( - logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none" - ) - * loss_mask - ).mean() + per_sample_loss = torch.nn.functional.cross_entropy( + logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none" + ) + loss = (per_sample_loss * loss_mask).sum() / loss_mask.sum() if grad_output is None: grad = None else: @@ -129,7 +127,8 @@ def _fused_cross_entropy_forward_backward( else: grad_base = exp_logits - sum_exp_logits * target - grad = grad_base.mul((grad_output / logits.size(0)) / sum_exp_logits) + normalizer = loss_mask.sum() if loss_mask is not None else logits.size(0) + grad = grad_base.mul((grad_output / normalizer) / sum_exp_logits) if logits_scale_factor != 1.0: grad *= logits_scale_factor if loss_mask is not None: @@ -155,7 +154,8 @@ def _fused_cross_entropy_forward_backward( if loss_mask is not None: per_sample_loss = per_sample_loss * loss_mask - loss = per_sample_loss.mean() + valid_tokens = loss_mask.sum() if loss_mask is not None else logits.size(0) + loss = per_sample_loss.sum() / valid_tokens if target_format != TargetFormat.labels and group is not None: all_reduce(loss, op=ReduceOp.AVG, group=group) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 295cdb74..2348d9c3 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -144,13 +144,22 @@ def triton_cross_entropy_forward_backward( losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) # TODO: Safe to do inplace? 
grad_logits = None if grad_output is None else torch.empty_like(logits) + + # Compute valid token count for loss masking + if target_format == TargetFormat.labels: + # For labels format, masking is done via negative labels + valid_count = (target >= 0).sum().item() # Convert to Python scalar + else: + # For logits/probabilities format, masking is done via loss_mask + valid_count = loss_mask.sum().item() if loss_mask is not None else n_rows + if target_format == TargetFormat.labels: triton_cross_entropy_forward_backward_kernel[(n_rows,)]( logits, target, grad_logits, losses, - None if grad_output is None else grad_output / n_rows, + None if grad_output is None else grad_output / valid_count, n_cols, logits.stride(0), None if grad_output is None else grad_logits.stride(0), @@ -167,7 +176,7 @@ def triton_cross_entropy_forward_backward( loss_mask, grad_logits, losses, - None if grad_output is None else grad_output / n_rows, + None if grad_output is None else grad_output / valid_count, n_cols, logits.stride(0), target.stride(0), @@ -177,4 +186,4 @@ def triton_cross_entropy_forward_backward( num_warps=num_warps, from_logits=target_format == TargetFormat.logits, ) - return losses.mean(), grad_logits + return losses.sum() / valid_count, grad_logits diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 53dac289..873d3339 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -31,6 +31,7 @@ class LanguageModelKwargs(BlockKwargs): chosen_spans = "chosen_spans" rejected_spans = "rejected_spans" loss_mask = "loss_mask" + total_valid_tokens = "total_valid_tokens" mask_inputs = "mask_inputs" diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 64e7f1cb..944ac1ab 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -158,6 +158,7 @@ def preprocess_batch( phase: PhaseType, iteration: int, metrics: dict | None = None, + total_valid_tokens: int | None = None, ) -> list[tuple[torch.Tensor, dict]]: # TODO Move batch splitting elsewhere, align interface with LayerBase assert self._is_setup @@ -252,6 +253,9 @@ def preprocess_batch( or self._config.decoder.block.distillation_model is not None ): kwargs[LanguageModelKwargs.loss_mask] = loss_mask + # Pass total_valid_tokens for correct gradient accumulation + if total_valid_tokens is not None: + kwargs[LanguageModelKwargs.total_valid_tokens] = total_valid_tokens labels = torch.where(loss_mask, labels, -100) kwargs[LanguageModelKwargs.labels] = ( diff --git a/fast_llm/models/multimodal/model.py b/fast_llm/models/multimodal/model.py index 890d5760..6cb18f74 100644 --- a/fast_llm/models/multimodal/model.py +++ b/fast_llm/models/multimodal/model.py @@ -159,9 +159,15 @@ def preprocess_batch( phase: PhaseType, iteration: int, metrics: dict | None = None, + total_valid_tokens: int | None = None, ) -> list[tuple[torch.Tensor, dict]]: preprocessed = super().preprocess_batch( - batch, preprocessed_meta, phase=phase, iteration=iteration, metrics=metrics + batch, + preprocessed_meta, + phase=phase, + iteration=iteration, + metrics=metrics, + total_valid_tokens=total_valid_tokens, ) # TODO: Support micro-sequences. assert len(preprocessed) == 1, "Micro-sequences not supported for MultiModalModel." 
diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 6156cb70..62ca454c 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -571,7 +571,8 @@ def _update_and_add_testing_config( }, compare_factor=1.5, # modes not supported with reference models - skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2"), + # TODO: ce4: cross_entropy_splits is broken, skipping it for nwo since its low priority and almost never used + skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2", "ce4"), ) _update_and_add_testing_config( @@ -592,7 +593,8 @@ def _update_and_add_testing_config( }, compare_factor=2, # Modes not supported with reference models - skip_tests=("sdp", "ms", "pp"), + # TODO: ce4: cross_entropy_splits is broken, skipping it for nwo since its low priority and almost never used + skip_tests=("sdp", "ms", "pp", "ce4"), ) _update_and_add_testing_config( From e2032f5f5ea592a3a2559376588a99dbc65ac6dc Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 16 Dec 2025 23:13:43 +0000 Subject: [PATCH 19/19] added loss comparison --- fast_llm/engine/multi_stage/config.py | 6 ++++++ fast_llm/engine/schedule/runner.py | 6 +++++- fast_llm/layers/language_model/head.py | 18 ++++++++++++++++-- tests/utils/distributed_configs.py | 3 +++ tests/utils/model_configs.py | 5 +++-- 5 files changed, 33 insertions(+), 5 deletions(-) diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 41736aed..733ffc5f 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -115,6 +115,12 @@ class StageConfig(Config): hint=FieldHint.logging, valid=check_field(Assert.geq, 0), ) + debug_losses: int = Field( + default=0, + desc="Log loss values after reduction.", + hint=FieldHint.logging, + valid=check_field(Assert.geq, 0), + ) debug_param_update: int = Field( default=0, desc="Log the parameters after update.", diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 5078bf4c..9be1ae41 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -19,7 +19,7 @@ from fast_llm.engine.optimizer.optimizer import Optimizer from fast_llm.engine.schedule.config import EventType, ScheduleConfig, StepType, StreamType from fast_llm.engine.schedule.schedule import Schedule, Step -from fast_llm.logging import log_memory_usage +from fast_llm.logging import log_memory_usage, log_tensor from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert @@ -297,6 +297,10 @@ def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: else: reduced_loss = 0.0 reduced_losses[name] = reduced_loss + if isinstance(reduced_loss, torch.Tensor) and self._multi_stage.config.multi_stage.debug_losses: + log_tensor( + f"loss: {name}", reduced_loss, level=self._multi_stage.config.multi_stage.debug_losses, log_fn=None + ) return { name: reduced_loss.item() if isinstance(reduced_loss, torch.Tensor) else reduced_loss for name, reduced_loss in reduced_losses.items() diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index b1d0c2ac..ba11ca4a 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -375,6 +375,21 @@ def _logits_cross_entropy_forward_backward( lm_loss, lm_grad = None, None if distillation_target is not None and self._config.distillation_loss_factor > 0.0: + # We need to scale the loss by (valid_tokens * num_micro_batches) / total_valid_tokens to 
correctly average the loss over micro-batches. + # The runner averages losses by dividing by num_micro_batches, so we need to account for that. + # Note: for grads this scaling is already in the 'grad_output' + total_valid_tokens = kwargs.get( + LanguageModelKwargs.total_valid_tokens + ) # number of non-masked tokens across all micro-batches. + num_micro_batches = kwargs.get("num_micro_batches", 1) + + if loss_mask is None or total_valid_tokens is None: + loss_scale_factor = 1 + else: + valid_tokens = loss_mask.sum() + # Scale by (valid_tokens * num_micro_batches) / total_valid_tokens + # This accounts for the runner dividing by num_micro_batches + loss_scale_factor = (valid_tokens * num_micro_batches) / total_valid_tokens if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), @@ -405,13 +420,12 @@ def _logits_cross_entropy_forward_backward( raise ValueError( f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" ) - distillation_loss = distillation_loss * self._config.distillation_loss_factor + distillation_loss = distillation_loss * self._config.distillation_loss_factor * loss_scale_factor else: distillation_loss, distillation_grad = None, None # TODO: de-allocate earlier. del logits - # TODO: Accumulate grads in-place to reduce memory and compute overhead. grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index 83ed6836..ce41d104 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -38,6 +38,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon # Biases have higher absolute error.
(None, "bias"): get_config(3e-3, 5e-5), (None, "gradient"): get_config(3e-3, 3e-5), + (None, "loss"): get_config(1e-5, 1e-6), } ) @@ -60,6 +61,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon (None, "bw"): get_config(1.5e-2, 1e-5), (None, "bias"): get_config(2e-2, 1e-3), (None, "gradient"): get_config(2e-2, 5e-5), + (None, "loss"): get_config(2e-4, 2e-4), } ) @@ -71,6 +73,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon (None, "bw"): get_config(3e-3, 1e-5, scale=2**16), (None, "bias"): get_config(3e-3, 1e-4, scale=2**16), (None, "gradient"): get_config(3e-3, 5e-5, scale=2**16), + (None, "loss"): get_config(1e-4, 1e-4), } ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 62ca454c..2ffd7788 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -249,6 +249,7 @@ def _update_and_add_testing_config( "debug_layer_outputs": _LOG_LEVEL, "debug_layer_gradients": _LOG_LEVEL, "debug_all_param_gradients": _LOG_LEVEL, + "debug_losses": _LOG_LEVEL, "debug_tensor_parallel": True, }, "distributed": { @@ -571,7 +572,7 @@ def _update_and_add_testing_config( }, compare_factor=1.5, # modes not supported with reference models - # TODO: ce4: cross_entropy_splits is broken, skipping it for nwo since its low priority and almost never used + # TODO: ce4: cross_entropy_splits is broken, skipping it for now since its low priority and almost never used skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2", "ce4"), ) @@ -593,7 +594,7 @@ def _update_and_add_testing_config( }, compare_factor=2, # Modes not supported with reference models - # TODO: ce4: cross_entropy_splits is broken, skipping it for nwo since its low priority and almost never used + # TODO: ce4: cross_entropy_splits is broken, skipping it for now since its low priority and almost never used skip_tests=("sdp", "ms", "pp", "ce4"), )