3 changes: 3 additions & 0 deletions .gitignore
@@ -210,3 +210,6 @@ checkpoints/

# Hydra outputs
outputs/

# Nsys profiles
*.nsys-rep
110 changes: 98 additions & 12 deletions bionemo-recipes/models/esm2/src/esm/collator.py
@@ -19,10 +19,12 @@
"""

import logging
import threading
from dataclasses import dataclass, field
from typing import Any, TypedDict

import datasets
import nvtx
import torch
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import pad_thd_sequences_for_cp
from transformers import DataCollator, DataCollatorForLanguageModeling
@@ -234,20 +236,50 @@ class TokenPackingDataset(torch.utils.data.IterableDataset):
"""Whether to drop the last batch if it's less than max_length."""
split_samples: bool = False
"""Whether to split samples to ensure batches have exactly max_tokens_per_batch tokens."""
pad_sequences_to_be_divisible_by: int | None = None
"""If set, account for per-sequence padding when accumulating batches.

Each sequence's contribution to the batch length is rounded up to the nearest multiple of this value,
matching the padding behavior of DataCollatorWithFlattening with the same parameter. When used with
split_samples=True, the split point is chosen so that the first part (after padding) exactly fills
the remaining batch capacity.
"""

def __post_init__(self):
"""Validate padding configuration."""
if (
self.pad_sequences_to_be_divisible_by is not None
and self.max_tokens_per_batch % self.pad_sequences_to_be_divisible_by != 0
):
logger.warning(
"max_tokens_per_batch (%d) is not divisible by pad_sequences_to_be_divisible_by (%d). "
"Batches may not fill to exactly max_tokens_per_batch when split_samples=True.",
self.max_tokens_per_batch,
self.pad_sequences_to_be_divisible_by,
)
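# Example: with max_tokens_per_batch=1000 and pad_sequences_to_be_divisible_by=64, padded lengths
# are all multiples of 64, so a batch can reach at most 960 tokens rather than exactly 1000.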

def _padded_len(self, length: int) -> int:
"""Return the padded length of a sequence, rounding up to the nearest multiple of pad_sequences_to_be_divisible_by."""
if self.pad_sequences_to_be_divisible_by is None:
return length
return -(-length // self.pad_sequences_to_be_divisible_by) * self.pad_sequences_to_be_divisible_by
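# Ceiling division: with pad_sequences_to_be_divisible_by=16, a sequence of length 23 is counted
# as -(-23 // 16) * 16 == 32 tokens of the batch budget.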

def __iter__(self):
"""Yield batches of samples, each with a variable number of tokens up to the maximum length.

When split_samples=True, ensures each batch has exactly max_tokens_per_batch by splitting
the final sample if needed. The remaining tokens from the split sample start the next batch.

When pad_sequences_to_be_divisible_by is set, each sequence's padded length is used when
accumulating batch sizes, so the total padded length of the batch never exceeds max_tokens_per_batch
(and matches it exactly when split_samples=True).

Returns:
A generator of batches of samples, each with a variable number of tokens up to the maximum length.
"""
samples = []
current_length = 0
for sample in iter(self.dataset):
current_length += len(sample["input_ids"])
current_length += self._padded_len(len(sample["input_ids"]))
if current_length == self.max_tokens_per_batch:
yield [*samples, sample]
samples = []
@@ -261,15 +293,19 @@ def __iter__(self):
samples = [sample]

else:
# Calculate how many tokens are already in the batch
tokens_in_batch = current_length - len(sample["input_ids"])
# Calculate how many tokens we can fit from this sample
# Calculate how many padded tokens are already in the batch
tokens_in_batch = current_length - self._padded_len(len(sample["input_ids"]))
# Calculate how many tokens we can fit from this sample, ensuring the
# padded length doesn't exceed the remaining capacity.
tokens_available = self.max_tokens_per_batch - tokens_in_batch
if self.pad_sequences_to_be_divisible_by is not None:
d = self.pad_sequences_to_be_divisible_by
tokens_available = (tokens_available // d) * d
first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available)
yield [*samples, first_part]
samples = [remaining_part]

current_length = len(samples[0]["input_ids"])
current_length = self._padded_len(len(samples[0]["input_ids"]))
else:
samples.append(sample)

@@ -349,6 +385,9 @@ def __call__(self, features) -> list[dict[str, Any]]:
"""
batch = self.collator(features)

# Remove the attention mask from the batch; it's not valid for CP.
batch.pop("attention_mask", None)

if self.is_causal_lm:
labels = torch.nn.functional.pad(batch["labels"], (0, 1), value=-100)
batch["labels"] = labels[..., 1:].contiguous()
@@ -376,12 +415,10 @@ def __call__(self, features) -> list[dict[str, Any]]:
max_length = seqlens_q.max().item()
elif self.qkv_format == "bshd":
max_length = batch["input_ids"].shape[1]
# For BSHD context parallelism, we can't handle padding, so we remove the attention mask.
del batch_shard["attention_mask"]
else:
raise ValueError(f"Unsupported qkv_format: {self.qkv_format}!")

batch_shard["max_length_k"] = batch_shard["max_length_q"] = max_length * round(max_length / 64)
batch_shard["max_length_k"] = batch_shard["max_length_q"] = ((max_length + 63) // 64) * 64
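# Rounds max_length up to the nearest multiple of 64 (e.g., 100 -> 128; 128 stays 128).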
combined_batch.append(batch_shard)

if self.tp_world_size is not None:
@@ -431,6 +468,9 @@ def __init__(
self.cp_tp_group = cp_tp_mesh.get_group()
self.num_cp_tp_ranks = cp_tp_mesh.size()
self._iterator = None
self._prefetch_thread: threading.Thread | None = None
self._prefetch_result: Any = None
self._cuda_device: int | None = None

logger.debug(
"Created ContextParallelDataLoaderWrapper on global rank %s, cp rank %s",
@@ -442,13 +482,56 @@ def __iter__(self):
"""Make the dataloader iterable."""
if self.cp_tp_rank == 0:
self._iterator = iter(self.dataloader)  # Yields collator output.
self.close()
# Capture CUDA device from main thread; torch.cuda.set_device is per-thread,
# so the background thread needs to set it explicitly.
self._cuda_device = torch.cuda.current_device() if torch.cuda.is_available() else None
self._kick_prefetch()
return self

@nvtx.annotate("ContextParallelDataLoaderWrapper __next__", color="blue")
def __next__(self):
"""Get the batch from the dataloader for the current CP rank."""
batch = self._send_data_to_cp_tp_ranks()
return batch

if self._prefetch_thread is not None:
self._prefetch_thread.join()
result = self._prefetch_result
if isinstance(result, StopIteration):
self._prefetch_thread = None
raise result
if isinstance(result, Exception):
self._prefetch_thread = None
raise result
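# Kick off the next prefetch before returning, so fetching and scattering the following batch
# overlaps with the caller's work on the current one.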
self._kick_prefetch()
return result

def _kick_prefetch(self):
"""Start a background thread to prefetch exactly one batch via scatter."""
self._prefetch_thread = threading.Thread(target=self._do_one_prefetch, daemon=True)
self._prefetch_thread.start()

def _do_one_prefetch(self):
"""Fetch one batch in the background.

This method calls _send_data_to_cp_tp_ranks to materialize the next batch for every rank in the
given CP/TP group and uses torch.distributed.scatter_object_list to scatter those batches to their
corresponding ranks. The result is stored in _prefetch_result and returned when __next__ is called.
"""
if self._cuda_device is not None:
torch.cuda.set_device(self._cuda_device)
try:
self._prefetch_result = self._send_data_to_cp_tp_ranks()
except StopIteration as e:
self._prefetch_result = e
except Exception as e:
self._prefetch_result = e

def close(self):
"""Stop the prefetch thread. Must be called before destroy_process_group()."""
if self._prefetch_thread is not None:
self._prefetch_thread.join(timeout=10)
self._prefetch_thread = None

@nvtx.annotate("ContextParallelDataLoaderWrapper _send_data_to_cp_tp_ranks", color="green")
def _send_data_to_cp_tp_ranks(self):
"""Send data to all the CP/TP ranks.

@@ -476,7 +559,8 @@ def _send_data_to_cp_tp_ranks(self):

"""
try:
combined_batch = next(self._iterator) if self.cp_tp_rank == 0 else None
with nvtx.annotate("ContextParallelDataLoaderWrapper next batch", color="green"):
combined_batch = next(self._iterator) if self.cp_tp_rank == 0 else None
except StopIteration as ex:
# If we encounter a StopIteration in the dataloader, we want to raise this error on all the CP ranks, so
# that the dataloader can be restarted.
@@ -679,6 +763,7 @@ def _pt_pad_to_multiple_of(batch: dict[str, Any], pad_to_multiple_of: int, token

# TODO(@jomitchell): Once this gets merged: https://github.com/NVIDIA/TransformerEngine/pull/2387
# we can replace this with the one in TransformerEngine.
@nvtx.annotate("collator._split_batch_by_cp_rank", color="green")
def _split_batch_by_cp_rank(
cu_seqlens_padded: torch.Tensor | None,
input_ids_padded: torch.Tensor,
@@ -852,6 +937,7 @@ class BatchType(TypedDict):
pad_between_seqs: bool


@nvtx.annotate("collator._scatter_batch_to_cp_tp_ranks", color="green")
def _scatter_batch_to_cp_tp_ranks(
all_batches: list[BatchType] | list[StopIteration], cp_tp_group: torch.distributed.ProcessGroup | None = None
) -> BatchType | StopIteration: