diff --git a/pithtrain/tasks/pretrain_language_model.py b/pithtrain/tasks/pretrain_language_model.py
index c740898..df7719e 100644
--- a/pithtrain/tasks/pretrain_language_model.py
+++ b/pithtrain/tasks/pretrain_language_model.py
@@ -422,12 +422,15 @@ def train_step(cfg: PretrainLanguageModelCfg, ctx: PretrainLanguageModelCtx) ->
     torch.distributed.all_reduce(peak_gpu_mem, op=torch.distributed.ReduceOp.MAX)
 
     # Collect the mean load balance loss (reduced across all ranks).
+    # The tracked values include the coefficient: lb_coef * E * dot(f, p).
+    # We divide it out so the logged metric is E * dot(f, p), where 1.0
+    # represents perfect balance (matches Megatron-LM convention).
     moe_load_balance_coef = cfg.training.moe_load_balance_coef
     lb_total, lb_count = MoELoadBalanceLossTracker.get_total_count_and_clear()
     if moe_load_balance_coef > 0:
         lb_stats = torch.tensor([lb_total, lb_count], device=device)
         torch.distributed.all_reduce(lb_stats, op=torch.distributed.ReduceOp.SUM)
-        lb_loss = (lb_stats[0] / lb_stats[1]).item() if lb_stats[1] > 0 else 0.0
+        lb_loss = (lb_stats[0] / lb_stats[1]).item() / moe_load_balance_coef if lb_stats[1] > 0 else 0.0
     else:
         lb_loss = 0.0
 