From 67a0a65fd2d0ec800085a81e62b42eca0b508b74 Mon Sep 17 00:00:00 2001 From: haok1402 Date: Sun, 5 Apr 2026 11:16:45 -0400 Subject: [PATCH] Rescale the load balancing loss so the value of 1.0 represents perfect balance. --- pithtrain/tasks/pretrain_language_model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pithtrain/tasks/pretrain_language_model.py b/pithtrain/tasks/pretrain_language_model.py index c740898..df7719e 100644 --- a/pithtrain/tasks/pretrain_language_model.py +++ b/pithtrain/tasks/pretrain_language_model.py @@ -422,12 +422,15 @@ def train_step(cfg: PretrainLanguageModelCfg, ctx: PretrainLanguageModelCtx) -> torch.distributed.all_reduce(peak_gpu_mem, op=torch.distributed.ReduceOp.MAX) # Collect the mean load balance loss (reduced across all ranks). + # The tracked values include the coefficient: lb_coef * E * dot(f, p). + # We divide it out so the logged metric is E * dot(f, p), where 1.0 + # represents perfect balance (matches Megatron-LM convention). moe_load_balance_coef = cfg.training.moe_load_balance_coef lb_total, lb_count = MoELoadBalanceLossTracker.get_total_count_and_clear() if moe_load_balance_coef > 0: lb_stats = torch.tensor([lb_total, lb_count], device=device) torch.distributed.all_reduce(lb_stats, op=torch.distributed.ReduceOp.SUM) - lb_loss = (lb_stats[0] / lb_stats[1]).item() if lb_stats[1] > 0 else 0.0 + lb_loss = (lb_stats[0] / lb_stats[1]).item() / moe_load_balance_coef if lb_stats[1] > 0 else 0.0 else: lb_loss = 0.0