Skip to content

Commit 8096ed4

Browse files
committed
Restore Megatron SFT loss scaling before backward
1 parent 3d8d1f5 commit 8096ed4

1 file changed

Lines changed: 2 additions & 9 deletions

File tree

src/art/megatron/train.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,6 @@ def run_megatron_sft_job(
 484 484        chunk.zero_grad_buffer()  # type: ignore[call-non-callable]
 485 485
 486 486        batch_loss = torch.tensor(0.0, device=device)
 487     -      local_trainable_tokens = 0.0
 488 487        for param_group in runtime.optimizer.param_groups:
 489 488            param_group["lr"] = job.learning_rates[batch_idx]
 490 489
@@ -499,7 +498,6 @@ def run_megatron_sft_job(
 499 498        position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
 500 499        shifted_labels = shift_tensor(labels, -100)
 501 500        mask = shifted_labels != -100
 502     -      local_trainable_tokens += float(mask.sum().item())
 503 501
 504 502        per_token_loss: torch.Tensor = runtime.model[0](
 505 503            input_ids=input_ids,
@@ -511,15 +509,10 @@ def run_megatron_sft_job(
 511 509            },
 512 510        )
 513 511        masked_loss = per_token_loss[mask].sum()
 514     -      masked_loss.backward()
     512 +      (masked_loss / float(global_trainable_tokens)).backward()
 515 513        batch_loss += masked_loss.detach()
 516 514
 517     -      num_tokens = torch.tensor(
 518     -          [local_trainable_tokens],
 519     -          device=device,
 520     -          dtype=torch.float32,
 521     -      )
 522     -      finalize_model_grads_extended(runtime.model, num_tokens=num_tokens)
     515 +      finalize_model_grads_extended(runtime.model)
 523 516        update_successful, grad_norm, num_zeros_in_grad = runtime.optimizer.step()
 524 517        runtime.optimizer.zero_grad()
 525 518        del update_successful, num_zeros_in_grad

0 commit comments

Comments (0)