@@ -484,7 +484,6 @@ def run_megatron_sft_job(
             chunk.zero_grad_buffer()  # type: ignore[call-non-callable]

         batch_loss = torch.tensor(0.0, device=device)
-        local_trainable_tokens = 0.0
         for param_group in runtime.optimizer.param_groups:
             param_group["lr"] = job.learning_rates[batch_idx]
@@ -499,7 +498,6 @@ def run_megatron_sft_job(
             position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
             shifted_labels = shift_tensor(labels, -100)
             mask = shifted_labels != -100
-            local_trainable_tokens += float(mask.sum().item())

             per_token_loss: torch.Tensor = runtime.model[0](
                 input_ids=input_ids,
@@ -511,15 +509,10 @@ def run_megatron_sft_job(
                 },
             )
             masked_loss = per_token_loss[mask].sum()
-            masked_loss.backward()
+            (masked_loss / float(global_trainable_tokens)).backward()
             batch_loss += masked_loss.detach()

-        num_tokens = torch.tensor(
-            [local_trainable_tokens],
-            device=device,
-            dtype=torch.float32,
-        )
-        finalize_model_grads_extended(runtime.model, num_tokens=num_tokens)
+        finalize_model_grads_extended(runtime.model)
         update_successful, grad_norm, num_zeros_in_grad = runtime.optimizer.step()
         runtime.optimizer.zero_grad()
         del update_successful, num_zeros_in_grad
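
The change moves loss normalization to backward time: instead of counting trainable tokens per rank and passing a `num_tokens` tensor to `finalize_model_grads_extended` for rescaling after the gradient reduction, each sample's token-summed loss is divided by a precomputed `global_trainable_tokens`, so every rank's gradients already carry the same global denominator. Below is a minimal sketch of how such a global count could be precomputed ahead of the loop, assuming a `torch.distributed` data-parallel setup; the `count_global_trainable_tokens` helper, its batch layout, and the `shift_tensor` re-implementation are illustrative assumptions, not code from this PR.

import torch
import torch.distributed as dist


def shift_tensor(labels: torch.Tensor, pad: int) -> torch.Tensor:
    # Assumed next-token shift: drop the first label, pad the tail.
    return torch.cat(
        [labels[:, 1:], labels.new_full((labels.size(0), 1), pad)], dim=1
    )


def count_global_trainable_tokens(label_batches, device, dp_group=None) -> float:
    # Count non-ignored (!= -100) label positions on this rank.
    local = torch.zeros(1, device=device, dtype=torch.float32)
    for labels in label_batches:
        local += (shift_tensor(labels, -100) != -100).sum()
    # Sum the per-rank counts so every rank divides its token-summed
    # loss by the same global denominator.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(local, op=dist.ReduceOp.SUM, group=dp_group)
    return float(local.item())

With the denominator fixed up front, summing gradients across data-parallel ranks reproduces the token-mean gradient directly, which is presumably why the per-rank `num_tokens` plumbing could be dropped.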