Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pithtrain/tasks/pretrain_language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,12 +422,15 @@ def train_step(cfg: PretrainLanguageModelCfg, ctx: PretrainLanguageModelCtx) ->
torch.distributed.all_reduce(peak_gpu_mem, op=torch.distributed.ReduceOp.MAX)

# Collect the mean load balance loss (reduced across all ranks).
# The tracked values include the coefficient: lb_coef * E * dot(f, p).
# We divide it out so the logged metric is E * dot(f, p), where 1.0
# represents perfect balance (matches Megatron-LM convention).
moe_load_balance_coef = cfg.training.moe_load_balance_coef
lb_total, lb_count = MoELoadBalanceLossTracker.get_total_count_and_clear()
if moe_load_balance_coef > 0:
lb_stats = torch.tensor([lb_total, lb_count], device=device)
torch.distributed.all_reduce(lb_stats, op=torch.distributed.ReduceOp.SUM)
lb_loss = (lb_stats[0] / lb_stats[1]).item() if lb_stats[1] > 0 else 0.0
lb_loss = (lb_stats[0] / lb_stats[1]).item() / moe_load_balance_coef
else:
lb_loss = 0.0

Expand Down
Loading