Skip to content

Commit 4609f97

Browse files
authored
change loss return format so that it can work with calculate_per_token_loss (NVIDIA-NeMo#12459)
* loss upscaling has been moved to MCore, so there is no need to handle it at the model level anymore
* return loss_sum and num_valid_tokens separately
* change num_tokens dtype to int
* fix a return type
* clean masked_token_loss
* Apply isort and black reformatting
* minor fix
* Apply isort and black reformatting
* minor fix
* Apply isort and black reformatting
* bug fix
* remove one unused import
* fix pylint error
* Apply isort and black reformatting

---------

Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Signed-off-by: xrennvidia <xrennvidia@users.noreply.github.com>
Co-authored-by: xrennvidia <xrennvidia@users.noreply.github.com>
1 parent 2f08584 commit 4609f97

2 files changed

Lines changed: 29 additions & 44 deletions

File tree

nemo/collections/speechlm/models/speech_to_text_llm_model.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
from nemo.collections.llm import fn
4747
from nemo.collections.llm.gpt.model.base import (
4848
GPTConfig,
49-
GPTModel,
5049
get_batch_on_this_context_parallel_rank,
5150
get_packed_seq_params,
5251
)
@@ -883,7 +882,7 @@ def inference_step(self, batch, mode):
883882

884883
if isinstance(forward_output, tuple):
885884
# reduce validation loss
886-
loss = self.validation_loss_reduction.forward(batch=batch, forward_out=forward_output)[1]['avg']
885+
loss = self.validation_loss_reduction.forward(batch=batch, forward_out=forward_output)[-1]['avg']
887886
else:
888887
# no labels provided, use a dummy loss value
889888
loss = 0.0
@@ -915,8 +914,14 @@ def inference_step(self, batch, mode):
915914
labels_text = clean_end_string(labels_text, self.tokenizer, data_cfg.end_string)
916915

917916
if data_cfg.get("remove_text_pc", False):
918-
preds_text = [remove_punctuations(p.lower(), data_cfg.get("punctuations", None)) for p in preds_text]
919-
labels_text = [remove_punctuations(l.lower(), data_cfg.get("punctuations", None)) for l in labels_text]
917+
preds_text = [
918+
remove_punctuations(pred_text.lower(), data_cfg.get("punctuations", None))
919+
for pred_text in preds_text
920+
]
921+
labels_text = [
922+
remove_punctuations(label_text.lower(), data_cfg.get("punctuations", None))
923+
for label_text in labels_text
924+
]
920925

921926
if data_cfg.get("log_every_n_steps", None) is not None:
922927
if batch_idx % data_cfg.log_every_n_steps == 0:

nemo/lightning/megatron_parallel.py

Lines changed: 20 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1742,7 +1742,7 @@ def __init__(self, validation_step: bool = False, val_drop_last: bool = True) ->
17421742

17431743
def forward(
17441744
self, batch: Dict[str, torch.Tensor], forward_out: torch.Tensor
1745-
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
1745+
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
17461746
"""Taken from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976 .""" # pylint: disable=line-too-long
17471747
from megatron.core import parallel_state
17481748

@@ -1752,33 +1752,30 @@ def forward(
17521752
if isinstance(forward_out, tuple):
17531753
forward_out, loss_mask = forward_out
17541754
batch["loss_mask"] = loss_mask
1755+
17551756
cp_size = parallel_state.get_context_parallel_world_size()
1756-
if cp_size == 1:
1757-
loss_for_ub = masked_token_loss(forward_out, batch["loss_mask"])
1757+
loss_sum_for_ub = masked_token_loss(forward_out, batch["loss_mask"], cp_size)
1758+
if cp_size == 1 or batch['num_valid_tokens_in_ub'] is None:
1759+
num_valid_tokens_in_ub = batch["loss_mask"].sum()
17581760
else:
1759-
loss_for_ub = masked_token_loss_context_parallel(
1760-
forward_out, batch["loss_mask"], batch['num_valid_tokens_in_ub']
1761-
)
1761+
num_valid_tokens_in_ub = batch['num_valid_tokens_in_ub']
1762+
if num_valid_tokens_in_ub < 0.5: # no valid tokens
1763+
num_valid_tokens_in_ub += 1.0
1764+
num_valid_tokens_in_ub = num_valid_tokens_in_ub.clone().detach().to(torch.int)
17621765

17631766
if self.validation_step and not self.val_drop_last:
1764-
num_valid_tokens_in_ub = batch["loss_mask"].sum()
1765-
if loss_for_ub.isnan():
1766-
assert batch["loss_mask"].count_nonzero() == 0, "Got NaN loss with non-empty input"
1767+
if loss_sum_for_ub.isnan():
1768+
assert num_valid_tokens_in_ub == 0, "Got NaN loss with non-empty input"
17671769
loss_sum_for_ub = torch.zeros_like(num_valid_tokens_in_ub)
1768-
else:
1769-
loss_sum_for_ub = num_valid_tokens_in_ub * loss_for_ub
17701770

17711771
loss_sum_and_ub_size_all_gpu = torch.cat(
1772-
[
1773-
loss_sum_for_ub.clone().detach().view(1),
1774-
torch.tensor([num_valid_tokens_in_ub], device=torch.cuda.current_device()).clone().detach(),
1775-
]
1772+
[loss_sum_for_ub.clone().detach().view(1), num_valid_tokens_in_ub]
17761773
)
17771774
torch.distributed.all_reduce(loss_sum_and_ub_size_all_gpu, group=parallel_state.get_data_parallel_group())
1778-
return loss_for_ub * cp_size, {"loss_sum_and_ub_size": loss_sum_and_ub_size_all_gpu}
1775+
return loss_sum_for_ub, num_valid_tokens_in_ub, {"loss_sum_and_ub_size": loss_sum_and_ub_size_all_gpu}
17791776

1780-
reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])
1781-
return loss_for_ub * cp_size, {"avg": reduced_loss}
1777+
reduced_loss = average_losses_across_data_parallel_group([loss_sum_for_ub / num_valid_tokens_in_ub])
1778+
return loss_sum_for_ub, num_valid_tokens_in_ub, {"avg": reduced_loss}
17821779

17831780
def reduce(self, losses_reduced_per_micro_batch) -> torch.Tensor:
17841781
"""Taken from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L535-L552 .""" # pylint: disable=line-too-long
@@ -1818,34 +1815,17 @@ def forward(
18181815
return super().forward(batch, forward_out)
18191816

18201817

1821-
def masked_token_loss(tensor: Tensor, mask: Tensor):
1818+
def masked_token_loss(tensor: Tensor, mask: Tensor, cp_size: int = 1):
18221819
"""
18231820
The function takes as input per-token loss and masks non-required values.
18241821
"""
18251822
losses = tensor.float()
18261823
loss_mask = mask.view(-1).float()
1827-
num_valid_tokens = loss_mask.sum()
1828-
if num_valid_tokens < 0.5: # no valid tokens
1829-
num_valid_tokens += 1.0
1830-
loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens # sequence level nll
1831-
1832-
return loss
1833-
1834-
1835-
def masked_token_loss_context_parallel(tensor: Tensor, mask: Tensor, num_valid_tokens_in_ub: int):
1836-
"""
1837-
masked token loss for CP > 1 as a separate function for readability.
1838-
"""
1839-
from megatron.core import parallel_state
1824+
loss = torch.sum(losses.view(-1) * loss_mask) # sequence level nll
1825+
if cp_size > 1:
1826+
from megatron.core import parallel_state
18401827

1841-
losses = tensor.float()
1842-
loss_mask = mask.view(-1).float()
1843-
if num_valid_tokens_in_ub is None:
1844-
num_valid_tokens_in_ub = loss_mask.sum()
1845-
if num_valid_tokens_in_ub < 0.5: # no valid tokens
1846-
num_valid_tokens_in_ub += 1.0
1847-
loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens_in_ub # sequence level nll
1848-
torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group())
1828+
torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group())
18491829

18501830
return loss
18511831

0 commit comments

Comments (0)