From c798c5cb4daff9a41bed7c0115f0fc70a54fbd7b Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Wed, 11 Feb 2026 08:53:00 -0800 Subject: [PATCH 1/9] llama3 cp performance fixes Signed-off-by: Peter St. John --- .gitignore | 3 + .../models/esm2/src/esm/collator.py | 94 +++++- .../models/esm2/tests/test_collator.py | 213 +++++++++++++ .../tests/test_collator_context_parallel.py | 148 ++++++++- bionemo-recipes/models/llama3/collator.py | 94 +++++- .../recipes/esm2_native_te/collator.py | 94 +++++- .../recipes/esm2_native_te/perf_logger.py | 85 +++--- .../recipes/esm2_native_te/train_fsdp2.py | 2 +- .../recipes/llama3_native_te/Dockerfile | 6 +- .../recipes/llama3_native_te/README.md | 60 ++++ .../recipes/llama3_native_te/collator.py | 94 +++++- .../hydra_config/L2_lingua_1b.yaml | 7 +- .../hydra_config/defaults.yaml | 7 +- .../recipes/llama3_native_te/perf_logger.py | 287 ++++++++++++------ .../llama3_native_te/tests/test_dataset.py | 3 +- .../tests/test_distributed_checkpointing.py | 8 + .../tests/test_perf_logger.py | 210 +++++++++++++ .../llama3_native_te/tests/test_train.py | 73 +++++ .../tests/test_train_two_gpu.py | 57 +++- .../recipes/llama3_native_te/train_ddp.py | 2 +- .../recipes/llama3_native_te/train_fsdp2.py | 4 +- .../llama3_native_te/train_fsdp2_cp.py | 24 +- 22 files changed, 1347 insertions(+), 228 deletions(-) create mode 100644 bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py diff --git a/.gitignore b/.gitignore index 93d8daa877..06e782185b 100644 --- a/.gitignore +++ b/.gitignore @@ -210,3 +210,6 @@ checkpoints/ # Hydra outputs outputs/ + +# Nsys profiles +*.nsys-rep diff --git a/bionemo-recipes/models/esm2/src/esm/collator.py b/bionemo-recipes/models/esm2/src/esm/collator.py index 43158d3f9b..e3cf6a9607 100644 --- a/bionemo-recipes/models/esm2/src/esm/collator.py +++ b/bionemo-recipes/models/esm2/src/esm/collator.py @@ -19,10 +19,12 @@ """ import logging +import threading from dataclasses import dataclass, field from typing 
import Any, TypedDict import datasets +import nvtx import torch from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import pad_thd_sequences_for_cp from transformers import DataCollator, DataCollatorForLanguageModeling @@ -234,6 +236,33 @@ class TokenPackingDataset(torch.utils.data.IterableDataset): """Whether to drop the last batch if it's less than max_length.""" split_samples: bool = False """Whether to split samples to ensure batches have exactly max_tokens_per_batch tokens.""" + pad_sequences_to_be_divisible_by: int | None = None + """If set, account for per-sequence padding when accumulating batches. + + Each sequence's contribution to the batch length is rounded up to the nearest multiple of this value, + matching the padding behavior of DataCollatorWithFlattening with the same parameter. When used with + split_samples=True, the split point is chosen so that the first part (after padding) exactly fills + the remaining batch capacity. + """ + + def __post_init__(self): + """Validate padding configuration.""" + if ( + self.pad_sequences_to_be_divisible_by is not None + and self.max_tokens_per_batch % self.pad_sequences_to_be_divisible_by != 0 + ): + logger.warning( + "max_tokens_per_batch (%d) is not divisible by pad_sequences_to_be_divisible_by (%d). " + "Batches may not fill to exactly max_tokens_per_batch when split_samples=True.", + self.max_tokens_per_batch, + self.pad_sequences_to_be_divisible_by, + ) + + def _padded_len(self, length: int) -> int: + """Return the padded length of a sequence, rounding up to the nearest multiple of pad_sequences_to_be_divisible_by.""" + if self.pad_sequences_to_be_divisible_by is None: + return length + return -(-length // self.pad_sequences_to_be_divisible_by) * self.pad_sequences_to_be_divisible_by def __iter__(self): """Yield batches of samples, each with a variable number of tokens up to the maximum length. 
@@ -241,13 +270,16 @@ def __iter__(self): When split_samples=True, ensures each batch has exactly max_tokens_per_batch by splitting the final sample if needed. The remaining tokens from the split sample start the next batch. + When pad_sequences_to_be_divisible_by is set, each sequence's padded length is used when + accumulating batch sizes, so the total padded length of the batch matches max_tokens_per_batch. + Returns: A generator of batches of samples, each with a variable number of tokens up to the maximum length. """ samples = [] current_length = 0 for sample in iter(self.dataset): - current_length += len(sample["input_ids"]) + current_length += self._padded_len(len(sample["input_ids"])) if current_length == self.max_tokens_per_batch: yield [*samples, sample] samples = [] @@ -261,15 +293,15 @@ def __iter__(self): samples = [sample] else: - # Calculate how many tokens are already in the batch - tokens_in_batch = current_length - len(sample["input_ids"]) + # Calculate how many padded tokens are already in the batch + tokens_in_batch = current_length - self._padded_len(len(sample["input_ids"])) # Calculate how many tokens we can fit from this sample tokens_available = self.max_tokens_per_batch - tokens_in_batch first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available) yield [*samples, first_part] samples = [remaining_part] - current_length = len(samples[0]["input_ids"]) + current_length = self._padded_len(len(samples[0]["input_ids"])) else: samples.append(sample) @@ -349,6 +381,9 @@ def __call__(self, features) -> list[dict[str, Any]]: """ batch = self.collator(features) + # Remove the attention mask from the batch, it's not valid for CP. 
+ batch.pop("attention_mask", None) + if self.is_causal_lm: labels = torch.nn.functional.pad(batch["labels"], (0, 1), value=-100) batch["labels"] = labels[..., 1:].contiguous() @@ -376,12 +411,10 @@ def __call__(self, features) -> list[dict[str, Any]]: max_length = seqlens_q.max().item() elif self.qkv_format == "bshd": max_length = batch["input_ids"].shape[1] - # For BSHD context parallelism, we can't handle padding, so we remove the attention mask. - del batch_shard["attention_mask"] else: raise ValueError(f"Unsupported qvk_format: {self.qkv_format}!") - batch_shard["max_length_k"] = batch_shard["max_length_q"] = max_length * round(max_length / 64) + batch_shard["max_length_k"] = batch_shard["max_length_q"] = ((max_length + 63) // 64) * 64 combined_batch.append(batch_shard) if self.tp_world_size is not None: @@ -431,6 +464,9 @@ def __init__( self.cp_tp_group = cp_tp_mesh.get_group() self.num_cp_tp_ranks = cp_tp_mesh.size() self._iterator = None + self._prefetch_thread: threading.Thread | None = None + self._prefetch_result: Any = None + self._cuda_device: int | None = None logger.debug( "Created ContextParallelDataLoaderWrapper on global rank %s, cp rank %s", @@ -442,13 +478,46 @@ def __iter__(self): """Make the dataloader iterable.""" if self.cp_tp_rank == 0: self._iterator = iter(self.dataloader) # < --- collator output. + self.close() + # Capture CUDA device from main thread; torch.cuda.set_device is per-thread, + # so the background thread needs to set it explicitly. 
+ self._cuda_device = torch.cuda.current_device() if torch.cuda.is_available() else None + self._kick_prefetch() return self + @nvtx.annotate("ContextParallelDataLoaderWrapper __next__", color="blue") def __next__(self): """Get the batch from the dataloader for the current CP rank.""" - batch = self._send_data_to_cp_tp_ranks() - return batch - + self._prefetch_thread.join() + result = self._prefetch_result + if isinstance(result, StopIteration): + self._prefetch_thread = None + raise result + self._kick_prefetch() + return result + + def _kick_prefetch(self): + """Start a background thread to prefetch exactly one batch via scatter.""" + self._prefetch_thread = threading.Thread(target=self._do_one_prefetch, daemon=True) + self._prefetch_thread.start() + + def _do_one_prefetch(self): + """Fetch one batch in the background. Stores result in _prefetch_result.""" + if self._cuda_device is not None: + torch.cuda.set_device(self._cuda_device) + try: + self._prefetch_result = self._send_data_to_cp_tp_ranks() + except Exception: + # Process group may have been destroyed; signal stop. + self._prefetch_result = StopIteration() + + def close(self): + """Stop the prefetch thread. Must be called before destroy_process_group().""" + if self._prefetch_thread is not None: + self._prefetch_thread.join(timeout=10) + self._prefetch_thread = None + + @nvtx.annotate("ContextParallelDataLoaderWrapper _send_data_to_cp_tp_ranks", color="green") def _send_data_to_cp_tp_ranks(self): """Send data to all the CP/TP ranks. @@ -476,7 +545,8 @@ def _send_data_to_cp_tp_ranks(self): """ try: - combined_batch = next(self._iterator) if self.cp_tp_rank == 0 else None + with nvtx.annotate("ContextParallelDataLoaderWrapper next batch", color="green"): + combined_batch = next(self._iterator) if self.cp_tp_rank == 0 else None except StopIteration as ex: # If we encounter a StopIteration in the dataloader, we want to raise this error on all the CP ranks, so # that the dataloader can be restarted. 
@@ -679,6 +749,7 @@ def _pt_pad_to_multiple_of(batch: dict[str, Any], pad_to_multiple_of: int, token # TODO(@jomitchell): Once this gets merged: https://github.com/NVIDIA/TransformerEngine/pull/2387 # we can replace this with the one in TransformerEngine. +@nvtx.annotate("collator._split_batch_by_cp_rank", color="green") def _split_batch_by_cp_rank( cu_seqlens_padded: torch.Tensor | None, input_ids_padded: torch.Tensor, @@ -852,6 +923,7 @@ class BatchType(TypedDict): pad_between_seqs: bool +@nvtx.annotate("collator._scatter_batch_to_cp_tp_ranks", color="green") def _scatter_batch_to_cp_tp_ranks( all_batches: list[BatchType] | list[StopIteration], cp_tp_group: torch.distributed.ProcessGroup | None = None ) -> BatchType | StopIteration: diff --git a/bionemo-recipes/models/esm2/tests/test_collator.py b/bionemo-recipes/models/esm2/tests/test_collator.py index 4aade1b3b2..f60c137fa9 100644 --- a/bionemo-recipes/models/esm2/tests/test_collator.py +++ b/bionemo-recipes/models/esm2/tests/test_collator.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import logging from unittest.mock import MagicMock import pytest @@ -795,3 +796,215 @@ def __iter__(self): assert "labels" in sample assert len(sample["input_ids"]) == len(sample["attention_mask"]) assert len(sample["input_ids"]) == len(sample["labels"]) + + +def test_token_packing_dataset_pad_sequences_to_be_divisible_by_warning(caplog): + """Test that a warning is issued when max_tokens_per_batch is not divisible by pad_sequences_to_be_divisible_by.""" + + class MockDataset(torch.utils.data.IterableDataset): + def __iter__(self): + yield {"input_ids": torch.arange(10)} + + with caplog.at_level(logging.WARNING): + TokenPackingDataset( + MockDataset(), + max_tokens_per_batch=100, + pad_sequences_to_be_divisible_by=7, + ) + + assert "not divisible" in caplog.text + + +def test_token_packing_dataset_pad_sequences_to_be_divisible_by_no_warning(caplog): + """Test that no warning is issued when max_tokens_per_batch is divisible by pad_sequences_to_be_divisible_by.""" + + class MockDataset(torch.utils.data.IterableDataset): + def __iter__(self): + yield {"input_ids": torch.arange(10)} + + with caplog.at_level(logging.WARNING): + TokenPackingDataset( + MockDataset(), + max_tokens_per_batch=100, + pad_sequences_to_be_divisible_by=4, + ) + + assert "not divisible" not in caplog.text + + +def test_token_packing_dataset_with_padding_accounts_for_padded_lengths(): + """Test that TokenPackingDataset accounts for padded lengths when pad_sequences_to_be_divisible_by is set.""" + + class MockDataset(torch.utils.data.IterableDataset): + def __iter__(self): + yield {"input_ids": list(range(5))} # padded to 8 + yield {"input_ids": list(range(3))} # padded to 4 + yield {"input_ids": list(range(7))} # padded to 8 + yield {"input_ids": list(range(6))} # padded to 8 + + # Without padding: 5+3+7 = 15 <= 20, 5+3+7+6 = 21 > 20 + # With padding (P=4): padded(5)=8, padded(3)=4, 8+4=12, +padded(7)=8 -> 20 == max + dataset = MockDataset() + token_packing_dataset = TokenPackingDataset( + 
dataset, + max_tokens_per_batch=20, + pad_sequences_to_be_divisible_by=4, + drop_last=False, + ) + batches = list(token_packing_dataset) + + # First batch: [5, 3, 7] -> padded: 8+4+8 = 20 == max + assert len(batches) == 2 + assert len(batches[0]) == 3 + assert [len(s["input_ids"]) for s in batches[0]] == [5, 3, 7] + # Second batch: [6] -> padded: 8 + assert len(batches[1]) == 1 + assert [len(s["input_ids"]) for s in batches[1]] == [6] + + +def test_token_packing_dataset_with_padding_and_split_samples(): + """Test TokenPackingDataset with split_samples=True and pad_sequences_to_be_divisible_by.""" + + class MockDataset(torch.utils.data.IterableDataset): + def __iter__(self): + yield {"input_ids": list(range(5))} # padded to 8 + yield {"input_ids": list(range(3))} # padded to 4 + yield {"input_ids": list(range(15))} # padded to 16, exceeds remaining (24-12=12) + + # P=4, max=24 + # Batch 1: padded(5)=8, padded(3)=4 -> 12 so far. Next: padded(15)=16 -> 12+16=28 > 24 + # tokens_available = 24 - 12 = 12. 
Split at 12: first_part=12 tokens, remaining=3 tokens + # Batch 1: [5, 3, 12] -> padded: 8 + 4 + 12 = 24 == max + # Batch 2: [3] -> padded: 4 + dataset = MockDataset() + token_packing_dataset = TokenPackingDataset( + dataset, + max_tokens_per_batch=24, + split_samples=True, + pad_sequences_to_be_divisible_by=4, + drop_last=False, + ) + batches = list(token_packing_dataset) + + assert len(batches) == 2 + assert [len(s["input_ids"]) for s in batches[0]] == [5, 3, 12] + assert [len(s["input_ids"]) for s in batches[1]] == [3] + + +def test_token_packing_dataset_with_padding_split_fills_exactly_max(tokenizer): + """Test that split_samples + pad_sequences_to_be_divisible_by produces batches that collate to exactly max_tokens.""" + pad_divisor = 4 + max_tokens = 24 + + class MockDataset(torch.utils.data.IterableDataset): + def __iter__(self): + # Generate many sequences of varying lengths + for length in [7, 5, 10, 3, 6, 9, 11, 4, 8, 13, 2, 14, 7, 5, 10]: + yield {"input_ids": list(range(length))} + + dataset = MockDataset() + token_packing_dataset = TokenPackingDataset( + dataset, + max_tokens_per_batch=max_tokens, + split_samples=True, + pad_sequences_to_be_divisible_by=pad_divisor, + drop_last=True, + ) + + mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.0) + collator = DataCollatorWithFlattening( + collator=mlm_collator, + pad_sequences_to_be_divisible_by=pad_divisor, + ) + + batches = list(token_packing_dataset) + assert len(batches) > 0, "Should produce at least one batch" + + for i, batch_samples in enumerate(batches): + collated = collator(batch_samples) + total_tokens = collated["input_ids"].numel() + assert total_tokens == max_tokens, ( + f"Batch {i}: expected exactly {max_tokens} tokens after collation, got {total_tokens}. 
" + f"Sample lengths: {[len(s['input_ids']) for s in batch_samples]}" + ) + + +def test_token_packing_dataset_with_padding_split_random_sequences(tokenizer): + """Test with random sequence lengths that split_samples + padding always produces exact-sized batches.""" + pad_divisor = 8 + max_tokens = 64 + + class MockDataset(torch.utils.data.IterableDataset): + def __iter__(self): + torch.manual_seed(42) + for _ in range(100): + length = torch.randint(1, 30, (1,)).item() + yield {"input_ids": list(range(length))} + + dataset = MockDataset() + token_packing_dataset = TokenPackingDataset( + dataset, + max_tokens_per_batch=max_tokens, + split_samples=True, + pad_sequences_to_be_divisible_by=pad_divisor, + drop_last=True, + ) + + mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.0) + collator = DataCollatorWithFlattening( + collator=mlm_collator, + pad_sequences_to_be_divisible_by=pad_divisor, + ) + + batches = list(token_packing_dataset) + assert len(batches) > 0, "Should produce at least one batch" + + for i, batch_samples in enumerate(batches): + collated = collator(batch_samples) + total_tokens = collated["input_ids"].numel() + assert total_tokens == max_tokens, ( + f"Batch {i}: expected exactly {max_tokens} tokens after collation, got {total_tokens}. 
" + f"Sample lengths: {[len(s['input_ids']) for s in batch_samples]}" + ) + + +def test_token_packing_dataset_with_padding_split_drop_last_false(tokenizer): + """Test that with drop_last=False, all batches except the last have exactly max_tokens.""" + pad_divisor = 4 + max_tokens = 16 + + class MockDataset(torch.utils.data.IterableDataset): + def __iter__(self): + for length in [5, 7, 3, 9, 6, 4]: + yield {"input_ids": list(range(length))} + + dataset = MockDataset() + token_packing_dataset = TokenPackingDataset( + dataset, + max_tokens_per_batch=max_tokens, + split_samples=True, + pad_sequences_to_be_divisible_by=pad_divisor, + drop_last=False, + ) + + mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.0) + collator = DataCollatorWithFlattening( + collator=mlm_collator, + pad_sequences_to_be_divisible_by=pad_divisor, + ) + + batches = list(token_packing_dataset) + assert len(batches) >= 2, "Should produce at least two batches" + + # All batches except the last must be exactly max_tokens + for i, batch_samples in enumerate(batches[:-1]): + collated = collator(batch_samples) + total_tokens = collated["input_ids"].numel() + assert total_tokens == max_tokens, ( + f"Batch {i}: expected exactly {max_tokens} tokens after collation, got {total_tokens}. " + f"Sample lengths: {[len(s['input_ids']) for s in batch_samples]}" + ) + + # Last batch can be <= max_tokens + last_collated = collator(batches[-1]) + assert last_collated["input_ids"].numel() <= max_tokens diff --git a/bionemo-recipes/models/esm2/tests/test_collator_context_parallel.py b/bionemo-recipes/models/esm2/tests/test_collator_context_parallel.py index 295ccc5d75..fe8991b42d 100644 --- a/bionemo-recipes/models/esm2/tests/test_collator_context_parallel.py +++ b/bionemo-recipes/models/esm2/tests/test_collator_context_parallel.py @@ -14,6 +14,7 @@ # limitations under the License. 
import copy +import threading from typing import Dict, Iterator, List from unittest import mock @@ -398,7 +399,7 @@ def run_roundtrip(base_batch): loader_rank1 = ContextParallelDataLoaderWrapper(None, cp_mesh_rank1) scatter_payload: Dict[str, List[Dict[str, torch.Tensor]]] = {} - current_rank = {"value": None} + data_ready = threading.Event() def fake_scatter( *, @@ -408,9 +409,14 @@ def fake_scatter( group_src, ): if scatter_object_input_list is not None: + # Rank 0: store the full payload and return shard 0 scatter_payload["data"] = scatter_object_input_list - assert "data" in scatter_payload, "Rank 0 payload missing" - scatter_object_output_list[0] = scatter_payload["data"][current_rank["value"]] + data_ready.set() + scatter_object_output_list[0] = scatter_object_input_list[0] + else: + # Rank 1: wait for rank 0's data, then return shard 1 + data_ready.wait(timeout=5) + scatter_object_output_list[0] = scatter_payload["data"][1] with ( mock.patch("esm.collator.torch.distributed.scatter_object_list", side_effect=fake_scatter), @@ -419,10 +425,7 @@ def fake_scatter( iter(loader_rank0) iter(loader_rank1) - current_rank["value"] = 0 batch_cp0 = next(loader_rank0) - - current_rank["value"] = 1 batch_cp1 = next(loader_rank1) return batch_cp0, batch_cp1 @@ -487,7 +490,7 @@ def run_roundtrip(base_batch): loader_rank1 = ContextParallelDataLoaderWrapper(None, cp_mesh_rank1) scatter_payload: Dict[str, List[Dict[str, torch.Tensor]]] = {} - current_rank = {"value": None} + data_ready = threading.Event() def fake_scatter( *, @@ -497,9 +500,14 @@ def fake_scatter( group_src, ): if scatter_object_input_list is not None: + # Rank 0: store the full payload and return shard 0 scatter_payload["data"] = scatter_object_input_list - assert "data" in scatter_payload, "Rank 0 payload missing" - scatter_object_output_list[0] = scatter_payload["data"][current_rank["value"]] + data_ready.set() + scatter_object_output_list[0] = scatter_object_input_list[0] + else: + # Rank 1: wait for rank 
0's data, then return shard 1 + data_ready.wait(timeout=5) + scatter_object_output_list[0] = scatter_payload["data"][1] with ( mock.patch("esm.collator.torch.distributed.scatter_object_list", side_effect=fake_scatter), @@ -508,10 +516,7 @@ def fake_scatter( iter(loader_rank0) iter(loader_rank1) - current_rank["value"] = 0 batch_cp0 = next(loader_rank0) - - current_rank["value"] = 1 batch_cp1 = next(loader_rank1) return batch_cp0, batch_cp1 @@ -940,8 +945,8 @@ def test_data_collator_for_context_parallel_returns_correct_list_size(tokenizer, # Create test sequences features = [ - {"input_ids": [0, 5, 6, 7, 8, 9, 10, 2]}, # 8 tokens - {"input_ids": [0, 11, 12, 13, 14, 15, 16, 17, 2]}, # 9 tokens + {"input_ids": [0, 5, 6, 7, 8, 9, 10, 2], "attention_mask": [1, 1, 1, 1, 1, 1, 1, 1]}, # 8 tokens + {"input_ids": [0, 11, 12, 13, 14, 15, 16, 17, 2], "attention_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1]}, # 9 tokens ] # Call the collator @@ -1044,6 +1049,121 @@ def test_data_collator_for_context_parallel_thd_causal_lm(tokenizer): torch.testing.assert_close(result[1]["shift_labels"], expected_rank1_shift_labels) +def test_data_collator_for_context_parallel_thd_correctness(tokenizer): + """Test that DataCollatorForContextParallel returns correct values for THD format. + + This test verifies: + 1. max_length_q and max_length_k are correctly rounded up to a multiple of 64 + 2. input_ids and labels have the correct shape (sharded by cp_world_size) + 3. cu_seq_lens_* tensors are preserved correctly + 4. 
All shards together reconstruct the original data + """ + cp_world_size = 2 + divisibility_factor = 2 * cp_world_size + + # Create the wrapped collator that produces padded THD batches - disable MLM for deterministic testing + base_collator = DataCollatorWithFlattening( + collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False), + pad_sequences_to_be_divisible_by=divisibility_factor, + ) + + # Create the context parallel collator + cp_collator = DataCollatorForContextParallel(collator=base_collator, cp_world_size=cp_world_size, qkv_format="thd") + + # Create test sequences - 8 tokens each for easy division + features = [ + {"input_ids": [0, 5, 6, 7, 8, 9, 10, 2]}, # 8 tokens + {"input_ids": [0, 11, 12, 13, 14, 15, 16, 2]}, # 8 tokens + ] + + # Call the collator + result = cp_collator(features) + + assert len(result) == cp_world_size + + # Verify max_length_q and max_length_k are rounded up to a multiple of 64 + for cp_rank, shard in enumerate(result): + max_length = shard["max_length_q"] + assert max_length == shard["max_length_k"], "max_length_q and max_length_k should be equal" + assert max_length % 64 == 0, f"CP rank {cp_rank}: max_length {max_length} should be a multiple of 64" + # Since our sequences are 8 tokens, padded to divisibility_factor=4, max_seqlen should be 8, + # and rounded up to 64 + assert max_length == 64, f"CP rank {cp_rank}: expected max_length=64, got {max_length}" + + # Verify input_ids shape - should be sharded along sequence dimension + # Original total tokens: 16 (8+8), each shard should have 16 / cp_world_size = 8 tokens + for cp_rank, shard in enumerate(result): + assert shard["input_ids"].shape[1] == 8, ( + f"CP rank {cp_rank}: expected input_ids shape [1, 8], got {shard['input_ids'].shape}" + ) + assert shard["labels"].shape[1] == 8, ( + f"CP rank {cp_rank}: expected labels shape [1, 8], got {shard['labels'].shape}" + ) + + # Verify that all shards together contain all the original tokens + all_input_ids = 
torch.cat([shard["input_ids"] for shard in result], dim=1) + # Check that all original tokens are present (sorted comparison since order may differ due to sharding) + expected_tokens = torch.tensor([[0, 5, 6, 7, 8, 9, 10, 2, 0, 11, 12, 13, 14, 15, 16, 2]], dtype=torch.int64) + torch.testing.assert_close( + torch.sort(all_input_ids.flatten())[0], + torch.sort(expected_tokens.flatten())[0], + msg="Sharded tokens don't match original tokens", + ) + + # Verify cu_seq_lens_q_padded and cu_seq_lens_k_padded are preserved in each shard + for cp_rank, shard in enumerate(result): + assert "cu_seq_lens_q_padded" in shard + assert "cu_seq_lens_k_padded" in shard + torch.testing.assert_close(shard["cu_seq_lens_q_padded"], shard["cu_seq_lens_k_padded"]) + + # Verify pad_between_seqs is True for THD format + for cp_rank, shard in enumerate(result): + assert shard["pad_between_seqs"] is True, f"CP rank {cp_rank}: pad_between_seqs should be True for THD format" + + +@pytest.mark.parametrize( + "max_seqlen,expected_rounded", + [ + (8, 64), # Small value rounds up to 64 + (64, 64), # Exactly 64 stays 64 + (65, 128), # Just over 64 rounds up to 128 + (100, 128), # 100 rounds up to 128 + (128, 128), # Exactly 128 stays 128 + (129, 192), # Just over 128 rounds up to 192 + ], +) +def test_data_collator_for_context_parallel_thd_max_length_rounding(tokenizer, max_seqlen, expected_rounded): + """Test that max_length_q/k is correctly rounded up to a multiple of 64 for various sequence lengths.""" + cp_world_size = 2 + divisibility_factor = 2 * cp_world_size + + # Create input_ids of the specified length (must be divisible by divisibility_factor) + # We pad to the next multiple of divisibility_factor if needed + padded_len = ((max_seqlen + divisibility_factor - 1) // divisibility_factor) * divisibility_factor + input_ids = [0, *list(range(5, 5 + padded_len - 2)), 2] # [CLS] + tokens + [SEP] + input_ids = input_ids[:padded_len] # Truncate to exact length + + # Create the collators + 
base_collator = DataCollatorWithFlattening( + collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False), + pad_sequences_to_be_divisible_by=divisibility_factor, + ) + cp_collator = DataCollatorForContextParallel(collator=base_collator, cp_world_size=cp_world_size, qkv_format="thd") + + # Use a single sequence to ensure max_seqlen is exactly what we expect after padding + features = [{"input_ids": input_ids}] + result = cp_collator(features) + + # The actual max_seqlen after padding may differ, but it should still round correctly + for shard in result: + max_length = shard["max_length_q"] + assert max_length % 64 == 0, f"max_length {max_length} should be a multiple of 64" + # Verify the rounding formula: ((x + 63) // 64) * 64 + actual_seqlen = (shard["cu_seq_lens_q_padded"][1:] - shard["cu_seq_lens_q_padded"][:-1]).max().item() + expected = ((actual_seqlen + 63) // 64) * 64 + assert max_length == expected, f"Expected max_length={expected} for seqlen={actual_seqlen}, got {max_length}" + + def test_data_collator_for_context_parallel_bshd(tokenizer): """Test that each shard from DataCollatorForContextParallel has all required keys from BatchType.""" diff --git a/bionemo-recipes/models/llama3/collator.py b/bionemo-recipes/models/llama3/collator.py index 43158d3f9b..e3cf6a9607 100644 --- a/bionemo-recipes/models/llama3/collator.py +++ b/bionemo-recipes/models/llama3/collator.py @@ -19,10 +19,12 @@ """ import logging +import threading from dataclasses import dataclass, field from typing import Any, TypedDict import datasets +import nvtx import torch from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import pad_thd_sequences_for_cp from transformers import DataCollator, DataCollatorForLanguageModeling @@ -234,6 +236,33 @@ class TokenPackingDataset(torch.utils.data.IterableDataset): """Whether to drop the last batch if it's less than max_length.""" split_samples: bool = False """Whether to split samples to ensure batches have 
exactly max_tokens_per_batch tokens.""" + pad_sequences_to_be_divisible_by: int | None = None + """If set, account for per-sequence padding when accumulating batches. + + Each sequence's contribution to the batch length is rounded up to the nearest multiple of this value, + matching the padding behavior of DataCollatorWithFlattening with the same parameter. When used with + split_samples=True, the split point is chosen so that the first part (after padding) exactly fills + the remaining batch capacity. + """ + + def __post_init__(self): + """Validate padding configuration.""" + if ( + self.pad_sequences_to_be_divisible_by is not None + and self.max_tokens_per_batch % self.pad_sequences_to_be_divisible_by != 0 + ): + logger.warning( + "max_tokens_per_batch (%d) is not divisible by pad_sequences_to_be_divisible_by (%d). " + "Batches may not fill to exactly max_tokens_per_batch when split_samples=True.", + self.max_tokens_per_batch, + self.pad_sequences_to_be_divisible_by, + ) + + def _padded_len(self, length: int) -> int: + """Return the padded length of a sequence, rounding up to the nearest multiple of pad_sequences_to_be_divisible_by.""" + if self.pad_sequences_to_be_divisible_by is None: + return length + return -(-length // self.pad_sequences_to_be_divisible_by) * self.pad_sequences_to_be_divisible_by def __iter__(self): """Yield batches of samples, each with a variable number of tokens up to the maximum length. @@ -241,13 +270,16 @@ def __iter__(self): When split_samples=True, ensures each batch has exactly max_tokens_per_batch by splitting the final sample if needed. The remaining tokens from the split sample start the next batch. + When pad_sequences_to_be_divisible_by is set, each sequence's padded length is used when + accumulating batch sizes, so the total padded length of the batch matches max_tokens_per_batch. + Returns: A generator of batches of samples, each with a variable number of tokens up to the maximum length. 
""" samples = [] current_length = 0 for sample in iter(self.dataset): - current_length += len(sample["input_ids"]) + current_length += self._padded_len(len(sample["input_ids"])) if current_length == self.max_tokens_per_batch: yield [*samples, sample] samples = [] @@ -261,15 +293,15 @@ def __iter__(self): samples = [sample] else: - # Calculate how many tokens are already in the batch - tokens_in_batch = current_length - len(sample["input_ids"]) + # Calculate how many padded tokens are already in the batch + tokens_in_batch = current_length - self._padded_len(len(sample["input_ids"])) # Calculate how many tokens we can fit from this sample tokens_available = self.max_tokens_per_batch - tokens_in_batch first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available) yield [*samples, first_part] samples = [remaining_part] - current_length = len(samples[0]["input_ids"]) + current_length = self._padded_len(len(samples[0]["input_ids"])) else: samples.append(sample) @@ -349,6 +381,9 @@ def __call__(self, features) -> list[dict[str, Any]]: """ batch = self.collator(features) + # Remove the attention mask from the batch, it's not valid for CP. + batch.pop("attention_mask", None) + if self.is_causal_lm: labels = torch.nn.functional.pad(batch["labels"], (0, 1), value=-100) batch["labels"] = labels[..., 1:].contiguous() @@ -376,12 +411,10 @@ def __call__(self, features) -> list[dict[str, Any]]: max_length = seqlens_q.max().item() elif self.qkv_format == "bshd": max_length = batch["input_ids"].shape[1] - # For BSHD context parallelism, we can't handle padding, so we remove the attention mask. 
- del batch_shard["attention_mask"] else: raise ValueError(f"Unsupported qvk_format: {self.qkv_format}!") - batch_shard["max_length_k"] = batch_shard["max_length_q"] = max_length * round(max_length / 64) + batch_shard["max_length_k"] = batch_shard["max_length_q"] = ((max_length + 63) // 64) * 64 combined_batch.append(batch_shard) if self.tp_world_size is not None: @@ -431,6 +464,9 @@ def __init__( self.cp_tp_group = cp_tp_mesh.get_group() self.num_cp_tp_ranks = cp_tp_mesh.size() self._iterator = None + self._prefetch_thread: threading.Thread | None = None + self._prefetch_result: Any = None + self._cuda_device: int | None = None logger.debug( "Created ContextParallelDataLoaderWrapper on global rank %s, cp rank %s", @@ -442,13 +478,46 @@ def __iter__(self): """Make the dataloader iterable.""" if self.cp_tp_rank == 0: self._iterator = iter(self.dataloader) # < --- collator output. + self.close() + # Capture CUDA device from main thread; torch.cuda.set_device is per-thread, + # so the background thread needs to set it explicitly. + self._cuda_device = torch.cuda.current_device() if torch.cuda.is_available() else None + self._kick_prefetch() return self + @nvtx.annotate("ContextParallelDataLoaderWrapper __next__", color="blue") def __next__(self): """Get the batch from the dataloader for the current CP rank.""" - batch = self._send_data_to_cp_tp_ranks() - return batch - + self._prefetch_thread.join() + result = self._prefetch_result + if isinstance(result, StopIteration): + self._prefetch_thread = None + raise result + self._kick_prefetch() + return result + + def _kick_prefetch(self): + """Start a background thread to prefetch exactly one batch via scatter.""" + self._prefetch_thread = threading.Thread(target=self._do_one_prefetch, daemon=True) + self._prefetch_thread.start() + + def _do_one_prefetch(self): + """Fetch one batch in the background. 
Stores result in _prefetch_result.""" + if self._cuda_device is not None: + torch.cuda.set_device(self._cuda_device) + try: + self._prefetch_result = self._send_data_to_cp_tp_ranks() + except Exception: + # Process group may have been destroyed; signal stop. + self._prefetch_result = StopIteration() + + def close(self): + """Stop the prefetch thread. Must be called before destroy_process_group().""" + if self._prefetch_thread is not None: + self._prefetch_thread.join(timeout=10) + self._prefetch_thread = None + + @nvtx.annotate("ContextParallelDataLoaderWrapper _send_data_to_cp_tp_ranks", color="green") def _send_data_to_cp_tp_ranks(self): """Send data to all the CP/TP ranks. @@ -476,7 +545,8 @@ def _send_data_to_cp_tp_ranks(self): """ try: - combined_batch = next(self._iterator) if self.cp_tp_rank == 0 else None + with nvtx.annotate("ContextParallelDataLoaderWrapper next batch", color="green"): + combined_batch = next(self._iterator) if self.cp_tp_rank == 0 else None except StopIteration as ex: # If we encounter a StopIteration in the dataloader, we want to raise this error on all the CP ranks, so # that the dataloader can be restarted. @@ -679,6 +749,7 @@ def _pt_pad_to_multiple_of(batch: dict[str, Any], pad_to_multiple_of: int, token # TODO(@jomitchell): Once this gets merged: https://github.com/NVIDIA/TransformerEngine/pull/2387 # we can replace this with the one in TransformerEngine. 
+@nvtx.annotate("collator._split_batch_by_cp_rank", color="green") def _split_batch_by_cp_rank( cu_seqlens_padded: torch.Tensor | None, input_ids_padded: torch.Tensor, @@ -852,6 +923,7 @@ class BatchType(TypedDict): pad_between_seqs: bool +@nvtx.annotate("collator._scatter_batch_to_cp_tp_ranks", color="green") def _scatter_batch_to_cp_tp_ranks( all_batches: list[BatchType] | list[StopIteration], cp_tp_group: torch.distributed.ProcessGroup | None = None ) -> BatchType | StopIteration: diff --git a/bionemo-recipes/recipes/esm2_native_te/collator.py b/bionemo-recipes/recipes/esm2_native_te/collator.py index 43158d3f9b..e3cf6a9607 100644 --- a/bionemo-recipes/recipes/esm2_native_te/collator.py +++ b/bionemo-recipes/recipes/esm2_native_te/collator.py @@ -19,10 +19,12 @@ """ import logging +import threading from dataclasses import dataclass, field from typing import Any, TypedDict import datasets +import nvtx import torch from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import pad_thd_sequences_for_cp from transformers import DataCollator, DataCollatorForLanguageModeling @@ -234,6 +236,33 @@ class TokenPackingDataset(torch.utils.data.IterableDataset): """Whether to drop the last batch if it's less than max_length.""" split_samples: bool = False """Whether to split samples to ensure batches have exactly max_tokens_per_batch tokens.""" + pad_sequences_to_be_divisible_by: int | None = None + """If set, account for per-sequence padding when accumulating batches. + + Each sequence's contribution to the batch length is rounded up to the nearest multiple of this value, + matching the padding behavior of DataCollatorWithFlattening with the same parameter. When used with + split_samples=True, the split point is chosen so that the first part (after padding) exactly fills + the remaining batch capacity. 
+ """ + + def __post_init__(self): + """Validate padding configuration.""" + if ( + self.pad_sequences_to_be_divisible_by is not None + and self.max_tokens_per_batch % self.pad_sequences_to_be_divisible_by != 0 + ): + logger.warning( + "max_tokens_per_batch (%d) is not divisible by pad_sequences_to_be_divisible_by (%d). " + "Batches may not fill to exactly max_tokens_per_batch when split_samples=True.", + self.max_tokens_per_batch, + self.pad_sequences_to_be_divisible_by, + ) + + def _padded_len(self, length: int) -> int: + """Return the padded length of a sequence, rounding up to the nearest multiple of pad_sequences_to_be_divisible_by.""" + if self.pad_sequences_to_be_divisible_by is None: + return length + return -(-length // self.pad_sequences_to_be_divisible_by) * self.pad_sequences_to_be_divisible_by def __iter__(self): """Yield batches of samples, each with a variable number of tokens up to the maximum length. @@ -241,13 +270,16 @@ def __iter__(self): When split_samples=True, ensures each batch has exactly max_tokens_per_batch by splitting the final sample if needed. The remaining tokens from the split sample start the next batch. + When pad_sequences_to_be_divisible_by is set, each sequence's padded length is used when + accumulating batch sizes, so the total padded length of the batch matches max_tokens_per_batch. + Returns: A generator of batches of samples, each with a variable number of tokens up to the maximum length. 
""" samples = [] current_length = 0 for sample in iter(self.dataset): - current_length += len(sample["input_ids"]) + current_length += self._padded_len(len(sample["input_ids"])) if current_length == self.max_tokens_per_batch: yield [*samples, sample] samples = [] @@ -261,15 +293,15 @@ def __iter__(self): samples = [sample] else: - # Calculate how many tokens are already in the batch - tokens_in_batch = current_length - len(sample["input_ids"]) + # Calculate how many padded tokens are already in the batch + tokens_in_batch = current_length - self._padded_len(len(sample["input_ids"])) # Calculate how many tokens we can fit from this sample tokens_available = self.max_tokens_per_batch - tokens_in_batch first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available) yield [*samples, first_part] samples = [remaining_part] - current_length = len(samples[0]["input_ids"]) + current_length = self._padded_len(len(samples[0]["input_ids"])) else: samples.append(sample) @@ -349,6 +381,9 @@ def __call__(self, features) -> list[dict[str, Any]]: """ batch = self.collator(features) + # Remove the attention mask from the batch, it's not valid for CP. + batch.pop("attention_mask", None) + if self.is_causal_lm: labels = torch.nn.functional.pad(batch["labels"], (0, 1), value=-100) batch["labels"] = labels[..., 1:].contiguous() @@ -376,12 +411,10 @@ def __call__(self, features) -> list[dict[str, Any]]: max_length = seqlens_q.max().item() elif self.qkv_format == "bshd": max_length = batch["input_ids"].shape[1] - # For BSHD context parallelism, we can't handle padding, so we remove the attention mask. 
- del batch_shard["attention_mask"] else: raise ValueError(f"Unsupported qvk_format: {self.qkv_format}!") - batch_shard["max_length_k"] = batch_shard["max_length_q"] = max_length * round(max_length / 64) + batch_shard["max_length_k"] = batch_shard["max_length_q"] = ((max_length + 63) // 64) * 64 combined_batch.append(batch_shard) if self.tp_world_size is not None: @@ -431,6 +464,9 @@ def __init__( self.cp_tp_group = cp_tp_mesh.get_group() self.num_cp_tp_ranks = cp_tp_mesh.size() self._iterator = None + self._prefetch_thread: threading.Thread | None = None + self._prefetch_result: Any = None + self._cuda_device: int | None = None logger.debug( "Created ContextParallelDataLoaderWrapper on global rank %s, cp rank %s", @@ -442,13 +478,46 @@ def __iter__(self): """Make the dataloader iterable.""" if self.cp_tp_rank == 0: self._iterator = iter(self.dataloader) # < --- collator output. + self.close() + # Capture CUDA device from main thread; torch.cuda.set_device is per-thread, + # so the background thread needs to set it explicitly. + self._cuda_device = torch.cuda.current_device() if torch.cuda.is_available() else None + self._kick_prefetch() return self + @nvtx.annotate("ContextParallelDataLoaderWrapper __next__", color="blue") def __next__(self): """Get the batch from the dataloader for the current CP rank.""" - batch = self._send_data_to_cp_tp_ranks() - return batch - + self._prefetch_thread.join() + result = self._prefetch_result + if isinstance(result, StopIteration): + self._prefetch_thread = None + raise result + self._kick_prefetch() + return result + + def _kick_prefetch(self): + """Start a background thread to prefetch exactly one batch via scatter.""" + self._prefetch_thread = threading.Thread(target=self._do_one_prefetch, daemon=True) + self._prefetch_thread.start() + + def _do_one_prefetch(self): + """Fetch one batch in the background. 
Stores result in _prefetch_result.""" + if self._cuda_device is not None: + torch.cuda.set_device(self._cuda_device) + try: + self._prefetch_result = self._send_data_to_cp_tp_ranks() + except Exception: + # Process group may have been destroyed; signal stop. + self._prefetch_result = StopIteration() + + def close(self): + """Stop the prefetch thread. Must be called before destroy_process_group().""" + if self._prefetch_thread is not None: + self._prefetch_thread.join(timeout=10) + self._prefetch_thread = None + + @nvtx.annotate("ContextParallelDataLoaderWrapper _send_data_to_cp_tp_ranks", color="green") def _send_data_to_cp_tp_ranks(self): """Send data to all the CP/TP ranks. @@ -476,7 +545,8 @@ def _send_data_to_cp_tp_ranks(self): """ try: - combined_batch = next(self._iterator) if self.cp_tp_rank == 0 else None + with nvtx.annotate("ContextParallelDataLoaderWrapper next batch", color="green"): + combined_batch = next(self._iterator) if self.cp_tp_rank == 0 else None except StopIteration as ex: # If we encounter a StopIteration in the dataloader, we want to raise this error on all the CP ranks, so # that the dataloader can be restarted. @@ -679,6 +749,7 @@ def _pt_pad_to_multiple_of(batch: dict[str, Any], pad_to_multiple_of: int, token # TODO(@jomitchell): Once this gets merged: https://github.com/NVIDIA/TransformerEngine/pull/2387 # we can replace this with the one in TransformerEngine. 
+@nvtx.annotate("collator._split_batch_by_cp_rank", color="green") def _split_batch_by_cp_rank( cu_seqlens_padded: torch.Tensor | None, input_ids_padded: torch.Tensor, @@ -852,6 +923,7 @@ class BatchType(TypedDict): pad_between_seqs: bool +@nvtx.annotate("collator._scatter_batch_to_cp_tp_ranks", color="green") def _scatter_batch_to_cp_tp_ranks( all_batches: list[BatchType] | list[StopIteration], cp_tp_group: torch.distributed.ProcessGroup | None = None ) -> BatchType | StopIteration: diff --git a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py index a320f165fe..c19e24e147 100644 --- a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py @@ -47,7 +47,7 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): self._dist_config = dist_config self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True) - self.min_loss = float("inf") + self.min_loss = torch.tensor(float("inf"), device=torch.device(f"cuda:{dist_config.local_rank}")) self.logging_frequency = args.logger.frequency # Track whether to collect memory stats (disabled by default for max performance) @@ -83,7 +83,7 @@ def log_step( step: int, batch: dict[str, torch.Tensor], outputs: MaskedLMOutput, - grad_norm: float, + grad_norm: torch.Tensor, lr: float, ): """Log a step to the logger and wandb. @@ -95,46 +95,47 @@ def log_step( grad_norm: The gradient norm of the step. lr: The learning rate of the step. """ - num_tokens = batch["input_ids"].numel() - # 1 is the padding token for ESM-2. 
- num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != 1].numel() - - self.min_loss = min(self.min_loss, outputs.loss.item()) - step_time, self.previous_step_time = time.perf_counter() - self.previous_step_time, time.perf_counter() - - self.metrics["train/loss"].update(outputs.loss) - self.metrics["train/learning_rate"].update(lr) - self.metrics["train/grad_norm"].update(grad_norm) - self.metrics["train/step_time"].update(step_time) - self.metrics["train/tokens_per_second_per_gpu"].update(num_tokens / step_time) - self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(num_unpadded_tokens / step_time) - self.metrics["train/total_unpadded_tokens_per_batch"].update(num_unpadded_tokens / self.logging_frequency) - - # Handle sequence packing for torchmetrics calculation. - if outputs.logits.dim() < 3: - outputs.logits = outputs.logits.unsqueeze(0) - - self.metrics["train/perplexity"].update(outputs.logits, batch["labels"]) - - if self.fp8_stats_enabled: - debug_api.step() - - if step % self.logging_frequency == 0 and step > 0: - memory_allocated = torch.cuda.memory_allocated() / (1024**3) - self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) - self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) - - metrics = self.metrics.compute() - self.metrics.reset() - metrics["train/global_step"] = torch.tensor(step, dtype=torch.int64) - - if self._dist_config.is_main_process(): - wandb.log(metrics, step=step) - self._progress_bar.update(self.logging_frequency) - self._progress_bar.set_postfix({"loss": outputs.loss.item()}) - - if self._dist_config.local_rank == 0: - logger.info(", ".join([f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()])) + with torch.no_grad(): + num_tokens = batch["input_ids"].numel() + # 1 is the padding token for ESM-2. 
+ num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != 1].numel() + + self.min_loss = torch.minimum(self.min_loss, outputs.loss) + step_time, self.previous_step_time = time.perf_counter() - self.previous_step_time, time.perf_counter() + + self.metrics["train/loss"].update(outputs.loss) + self.metrics["train/learning_rate"].update(lr) + self.metrics["train/grad_norm"].update(grad_norm) + self.metrics["train/step_time"].update(step_time) + self.metrics["train/tokens_per_second_per_gpu"].update(num_tokens / step_time) + self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(num_unpadded_tokens / step_time) + self.metrics["train/total_unpadded_tokens_per_batch"].update(num_unpadded_tokens / self.logging_frequency) + + # Handle sequence packing for torchmetrics calculation. + if outputs.logits.dim() < 3: + outputs.logits = outputs.logits.unsqueeze(0) + + self.metrics["train/perplexity"].update(outputs.logits, batch["labels"]) + + if self.fp8_stats_enabled: + debug_api.step() + + if step % self.logging_frequency == 0 and step > 0: + memory_allocated = torch.cuda.memory_allocated() / (1024**3) + self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) + self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) + + metrics = self.metrics.compute() + self.metrics.reset() + metrics["train/global_step"] = torch.tensor(step, dtype=torch.int64) + + if self._dist_config.is_main_process(): + wandb.log(metrics, step=step) + self._progress_bar.update(self.logging_frequency) + self._progress_bar.set_postfix({"loss": outputs.loss.item()}) + + if self._dist_config.local_rank == 0: + logger.info(", ".join([f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()])) def finish(self): """Finish the logger and close the progress bar.""" diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 64ff1a97c7..71d7066c6a 100644 --- 
a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
+++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
@@ -161,7 +161,7 @@ def main(args: DictConfig) -> float | None:
         loss.backward()
 
         # Compute and clip gradient norms.
-        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
+        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
 
         # Step optimizer.
         optimizer.step()
diff --git a/bionemo-recipes/recipes/llama3_native_te/Dockerfile b/bionemo-recipes/recipes/llama3_native_te/Dockerfile
index 03eadca61e..da6d873446 100644
--- a/bionemo-recipes/recipes/llama3_native_te/Dockerfile
+++ b/bionemo-recipes/recipes/llama3_native_te/Dockerfile
@@ -1,5 +1,9 @@
 # syntax=docker/dockerfile:1.4
-FROM nvcr.io/nvidia/pytorch:26.01-py3
+# FROM nvcr.io/nvidia/pytorch:26.01-py3
+
+# Note: there's currently a bug in GQA + CP on Hopper but it's fixed in cuDNN 9.18. This image is NVIDIA-internal, but
+# leaving this as a note to remove once the bug is fixed in an nvidia/pytorch image.
+FROM gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-pytorch-py3-base
 
 RUN --mount=type=cache,target=/root/.cache/pip \
     --mount=type=bind,source=requirements.txt,target=/requirements.txt \
diff --git a/bionemo-recipes/recipes/llama3_native_te/README.md b/bionemo-recipes/recipes/llama3_native_te/README.md
index 630adff7c4..b7416929cc 100644
--- a/bionemo-recipes/recipes/llama3_native_te/README.md
+++ b/bionemo-recipes/recipes/llama3_native_te/README.md
@@ -210,6 +210,66 @@ These examples show how to save and resume your dataloader by passing the datalo
 and `load_checkpoint_*` functions using the `StatefulDataLoader` class from `torchdata`. See `checkpoint.py` for
 implementation details.
 
+## Performance Profiling with NVIDIA Nsight Systems + +This recipe includes built-in support for profiling with NVIDIA Nsight Systems, which provides detailed performance +traces including CUDA kernels, CPU activities, memory operations, and NVTX ranges. The profiler allows you to specify +the exact training step range to profile. + +### Basic Usage (Single GPU) + +To profile a training run on a single GPU: + +```bash +nsys profile \ + -o nsight_trace \ + --trace=cuda,nvtx,osrt,cudnn,cublas \ + --pytorch=autograd-nvtx \ + --capture-range=cudaProfilerApi \ + --capture-range-end=stop \ + python train_fsdp2.py \ + profiler.enabled=true \ + profiler.start_step=10 \ + profiler.end_step=15 +``` + +**Profiler Configuration Parameters:** + +- `profiler.enabled`: Enable/disable profiling (default: false) +- `profiler.start_step`: Training step at which to start profiling (default: 10) +- `profiler.end_step`: Training step at which to end profiling (default: 15) + +**Nsight Systems Flags:** + +- `--pytorch=autograd-nvtx`: Adds NVTX markers for PyTorch autograd operations (forward/backward passes, optimizer steps). This helps visualize the training loop structure and identify bottlenecks in the computation graph. +- `--pytorch-backtrace=cuda`: Captures Python backtraces for CUDA kernel launches, helping identify which Python code triggered each kernel. This is invaluable for debugging performance issues and understanding which operations are expensive. +- `--python-sampling=true` (optional): Periodically samples Python call stacks to identify CPU-side bottlenecks. Useful when investigating data loading, preprocessing, or Python overhead. Adds ~5-15% overhead, so only use when needed. + +**Note**: The PyTorch-specific flags (`--pytorch=autograd-nvtx` and `--pytorch-backtrace=cuda`) add minimal overhead but provide significantly more detailed insights into PyTorch operations, making them highly recommended for training workload profiling. 
Use `--python-sampling=true` only when investigating CPU/Python performance. + +The profiler will start capturing performance data at `start_step` and stop at `end_step`. It's recommended to start profiling after a few steps to allow training to stabilize. + +### Multi-GPU Profiling + +For distributed training, **profiling is only performed on global rank 0** to minimize overhead and avoid redundant data +collection. Other ranks will skip profiling automatically. + +#### Multi-GPU on Single Node + +```bash +nsys profile \ + -o nsight_trace_rank0 \ + --trace=cuda,nvtx,osrt,cudnn,cublas \ + --pytorch=autograd-nvtx \ + --pytorch-backtrace=cuda \ + --capture-range=cudaProfilerApi \ + --capture-range-end=stop \ + torchrun --nproc_per_node=2 train_fsdp2.py \ + profiler.enabled=true +``` + +For more information on Nsight Systems, see the [official documentation](https://docs.nvidia.com/nsight-systems/). + ## Running Inference with the Trained Model Models can be loaded from the final checkpoint directory using the `AutoModelForCausalLM` method (or diff --git a/bionemo-recipes/recipes/llama3_native_te/collator.py b/bionemo-recipes/recipes/llama3_native_te/collator.py index 43158d3f9b..e3cf6a9607 100644 --- a/bionemo-recipes/recipes/llama3_native_te/collator.py +++ b/bionemo-recipes/recipes/llama3_native_te/collator.py @@ -19,10 +19,12 @@ """ import logging +import threading from dataclasses import dataclass, field from typing import Any, TypedDict import datasets +import nvtx import torch from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import pad_thd_sequences_for_cp from transformers import DataCollator, DataCollatorForLanguageModeling @@ -234,6 +236,33 @@ class TokenPackingDataset(torch.utils.data.IterableDataset): """Whether to drop the last batch if it's less than max_length.""" split_samples: bool = False """Whether to split samples to ensure batches have exactly max_tokens_per_batch tokens.""" + pad_sequences_to_be_divisible_by: int | 
None = None + """If set, account for per-sequence padding when accumulating batches. + + Each sequence's contribution to the batch length is rounded up to the nearest multiple of this value, + matching the padding behavior of DataCollatorWithFlattening with the same parameter. When used with + split_samples=True, the split point is chosen so that the first part (after padding) exactly fills + the remaining batch capacity. + """ + + def __post_init__(self): + """Validate padding configuration.""" + if ( + self.pad_sequences_to_be_divisible_by is not None + and self.max_tokens_per_batch % self.pad_sequences_to_be_divisible_by != 0 + ): + logger.warning( + "max_tokens_per_batch (%d) is not divisible by pad_sequences_to_be_divisible_by (%d). " + "Batches may not fill to exactly max_tokens_per_batch when split_samples=True.", + self.max_tokens_per_batch, + self.pad_sequences_to_be_divisible_by, + ) + + def _padded_len(self, length: int) -> int: + """Return the padded length of a sequence, rounding up to the nearest multiple of pad_sequences_to_be_divisible_by.""" + if self.pad_sequences_to_be_divisible_by is None: + return length + return -(-length // self.pad_sequences_to_be_divisible_by) * self.pad_sequences_to_be_divisible_by def __iter__(self): """Yield batches of samples, each with a variable number of tokens up to the maximum length. @@ -241,13 +270,16 @@ def __iter__(self): When split_samples=True, ensures each batch has exactly max_tokens_per_batch by splitting the final sample if needed. The remaining tokens from the split sample start the next batch. + When pad_sequences_to_be_divisible_by is set, each sequence's padded length is used when + accumulating batch sizes, so the total padded length of the batch matches max_tokens_per_batch. + Returns: A generator of batches of samples, each with a variable number of tokens up to the maximum length. 
""" samples = [] current_length = 0 for sample in iter(self.dataset): - current_length += len(sample["input_ids"]) + current_length += self._padded_len(len(sample["input_ids"])) if current_length == self.max_tokens_per_batch: yield [*samples, sample] samples = [] @@ -261,15 +293,15 @@ def __iter__(self): samples = [sample] else: - # Calculate how many tokens are already in the batch - tokens_in_batch = current_length - len(sample["input_ids"]) + # Calculate how many padded tokens are already in the batch + tokens_in_batch = current_length - self._padded_len(len(sample["input_ids"])) # Calculate how many tokens we can fit from this sample tokens_available = self.max_tokens_per_batch - tokens_in_batch first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available) yield [*samples, first_part] samples = [remaining_part] - current_length = len(samples[0]["input_ids"]) + current_length = self._padded_len(len(samples[0]["input_ids"])) else: samples.append(sample) @@ -349,6 +381,9 @@ def __call__(self, features) -> list[dict[str, Any]]: """ batch = self.collator(features) + # Remove the attention mask from the batch, it's not valid for CP. + batch.pop("attention_mask", None) + if self.is_causal_lm: labels = torch.nn.functional.pad(batch["labels"], (0, 1), value=-100) batch["labels"] = labels[..., 1:].contiguous() @@ -376,12 +411,10 @@ def __call__(self, features) -> list[dict[str, Any]]: max_length = seqlens_q.max().item() elif self.qkv_format == "bshd": max_length = batch["input_ids"].shape[1] - # For BSHD context parallelism, we can't handle padding, so we remove the attention mask. 
- del batch_shard["attention_mask"] else: raise ValueError(f"Unsupported qvk_format: {self.qkv_format}!") - batch_shard["max_length_k"] = batch_shard["max_length_q"] = max_length * round(max_length / 64) + batch_shard["max_length_k"] = batch_shard["max_length_q"] = ((max_length + 63) // 64) * 64 combined_batch.append(batch_shard) if self.tp_world_size is not None: @@ -431,6 +464,9 @@ def __init__( self.cp_tp_group = cp_tp_mesh.get_group() self.num_cp_tp_ranks = cp_tp_mesh.size() self._iterator = None + self._prefetch_thread: threading.Thread | None = None + self._prefetch_result: Any = None + self._cuda_device: int | None = None logger.debug( "Created ContextParallelDataLoaderWrapper on global rank %s, cp rank %s", @@ -442,13 +478,46 @@ def __iter__(self): """Make the dataloader iterable.""" if self.cp_tp_rank == 0: self._iterator = iter(self.dataloader) # < --- collator output. + self.close() + # Capture CUDA device from main thread; torch.cuda.set_device is per-thread, + # so the background thread needs to set it explicitly. + self._cuda_device = torch.cuda.current_device() if torch.cuda.is_available() else None + self._kick_prefetch() return self + @nvtx.annotate("ContextParallelDataLoaderWrapper __next__", color="blue") def __next__(self): """Get the batch from the dataloader for the current CP rank.""" - batch = self._send_data_to_cp_tp_ranks() - return batch - + self._prefetch_thread.join() + result = self._prefetch_result + if isinstance(result, StopIteration): + self._prefetch_thread = None + raise result + self._kick_prefetch() + return result + + def _kick_prefetch(self): + """Start a background thread to prefetch exactly one batch via scatter.""" + self._prefetch_thread = threading.Thread(target=self._do_one_prefetch, daemon=True) + self._prefetch_thread.start() + + def _do_one_prefetch(self): + """Fetch one batch in the background. 
Stores result in _prefetch_result.""" + if self._cuda_device is not None: + torch.cuda.set_device(self._cuda_device) + try: + self._prefetch_result = self._send_data_to_cp_tp_ranks() + except Exception: + # Process group may have been destroyed; signal stop. + self._prefetch_result = StopIteration() + + def close(self): + """Stop the prefetch thread. Must be called before destroy_process_group().""" + if self._prefetch_thread is not None: + self._prefetch_thread.join(timeout=10) + self._prefetch_thread = None + + @nvtx.annotate("ContextParallelDataLoaderWrapper _send_data_to_cp_tp_ranks", color="green") def _send_data_to_cp_tp_ranks(self): """Send data to all the CP/TP ranks. @@ -476,7 +545,8 @@ def _send_data_to_cp_tp_ranks(self): """ try: - combined_batch = next(self._iterator) if self.cp_tp_rank == 0 else None + with nvtx.annotate("ContextParallelDataLoaderWrapper next batch", color="green"): + combined_batch = next(self._iterator) if self.cp_tp_rank == 0 else None except StopIteration as ex: # If we encounter a StopIteration in the dataloader, we want to raise this error on all the CP ranks, so # that the dataloader can be restarted. @@ -679,6 +749,7 @@ def _pt_pad_to_multiple_of(batch: dict[str, Any], pad_to_multiple_of: int, token # TODO(@jomitchell): Once this gets merged: https://github.com/NVIDIA/TransformerEngine/pull/2387 # we can replace this with the one in TransformerEngine. 
+@nvtx.annotate("collator._split_batch_by_cp_rank", color="green") def _split_batch_by_cp_rank( cu_seqlens_padded: torch.Tensor | None, input_ids_padded: torch.Tensor, @@ -852,6 +923,7 @@ class BatchType(TypedDict): pad_between_seqs: bool +@nvtx.annotate("collator._scatter_batch_to_cp_tp_ranks", color="green") def _scatter_batch_to_cp_tp_ranks( all_batches: list[BatchType] | list[StopIteration], cp_tp_group: torch.distributed.ProcessGroup | None = None ) -> BatchType | StopIteration: diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_1b.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_1b.yaml index dbba583387..d004c72976 100644 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_1b.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_1b.yaml @@ -55,8 +55,5 @@ checkpoint: profiler: enabled: false - schedule: - wait: 125 - warmup: 125 - active: 10 - repeat: 1 + start_step: 250 # Previously wait + warmup (125 + 125) + end_step: 260 # Previously start_step + active (250 + 10) diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml index 0bbaf333e3..d6c181598f 100644 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml @@ -77,8 +77,5 @@ fp8_stats_config: profiler: enabled: false - schedule: - wait: 10 - warmup: 10 - active: 3 - repeat: 1 + start_step: 10 + end_step: 15 diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py index 56ccd7a349..57f2e8bdfd 100644 --- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py @@ -14,16 +14,16 @@ # limitations under the License. 
import logging +import os import time -from pathlib import Path import nvdlfw_inspect.api as debug_api +import nvtx import torch import torchmetrics import wandb -from hydra.core.hydra_config import HydraConfig from omegaconf import DictConfig, OmegaConf -from torch.profiler import profile, schedule, tensorboard_trace_handler +from torch.distributed.tensor import DTensor from tqdm import tqdm from transformers.modeling_outputs import CausalLMOutputWithPast @@ -49,7 +49,8 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): self._dist_config = dist_config self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True) - self.min_loss = float("inf") + self._device = torch.device(f"cuda:{dist_config.local_rank}") + self.min_loss = torch.tensor(float("inf"), device=self._device) self.logging_frequency = args.logger.frequency @@ -67,7 +68,7 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): self.metrics = torchmetrics.MetricCollection(metrics_dict) # We move metrics to a GPU device so we can use torch.distributed to aggregate them before logging. 
- self.metrics.to(torch.device(f"cuda:{dist_config.local_rank}")) + self.metrics.to(self._device) self.previous_step_time = time.perf_counter() self._profiler = None @@ -77,39 +78,53 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): self._progress_bar = tqdm(total=args.num_train_steps, desc="Training") if args.profiler.enabled: - self._profiler = setup_profiler(args, self._wandb_run) - self._profiler.__enter__() + self._profiler = NsightProfiler( + **args.profiler, + wandb_run=self._wandb_run, + dist_config=dist_config, + ) # Gradient accumulation tracking self.num_tokens = 0 - self.num_unpadded_tokens = 0 - self.running_loss = 0.0 + self.num_unpadded_tokens = torch.tensor(0, dtype=torch.int64, device=self._device) + self.running_loss = torch.tensor(0.0, device=self._device) self.grad_acc_step_count = 0 # Whether to step debug_api.step() after each step self.fp8_stats_enabled = args.fp8_stats_config.enabled - def log_micro_step(self, batch: dict[str, torch.Tensor], outputs: CausalLMOutputWithPast): + @nvtx.annotate("PerfLogger.log_micro_step", color="pink") + def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: CausalLMOutputWithPast): """Store data on micro step for gradient accumulation metrics. Args: + step: The step number. batch: The batch of data for the micro step. outputs: The outputs of the micro step. 
""" - self.grad_acc_step_count += 1 - self.num_tokens += batch["input_ids"].numel() - # Use attention_mask to count unpadded tokens (works for both BSHD and THD) - if "attention_mask" in batch: - self.num_unpadded_tokens += batch["attention_mask"].sum().item() - else: - # Fallback for pure sequence packing with no padding: all tokens are unpadded - self.num_unpadded_tokens += batch["input_ids"].numel() - self.running_loss += outputs.loss.item() + if self._dist_config.local_rank == 0: + logger.debug("log_micro_step") + + assert outputs.loss is not None, "Loss is None" + with torch.no_grad(): + self.grad_acc_step_count += 1 + self.running_loss += outputs.loss + + if step % self.logging_frequency == 0 and step > 0: + self.num_tokens += batch["input_ids"].numel() + # Use attention_mask to count unpadded tokens (works for both BSHD and THD) + if "attention_mask" in batch: + self.num_unpadded_tokens += batch["attention_mask"].sum() + else: + # Fallback for pure sequence packing with no padding: all tokens are unpadded + self.num_unpadded_tokens += batch["input_ids"].numel() + + @nvtx.annotate("PerfLogger.log_step", color="purple") def log_step( self, step: int, - grad_norm: float, + grad_norm: torch.Tensor | DTensor, lr: float, ): """Log a step to the logger and wandb. @@ -119,58 +134,68 @@ def log_step( grad_norm: The gradient norm of the step. lr: The learning rate of the step. """ - # Use accumulated metrics from gradient accumulation - assert self.grad_acc_step_count > 0, ( - f"Gradient accumulation steps ({self.grad_acc_step_count}) must be greater than 0, " - f"and can be incremented by log_micro_step()." 
- ) - - avg_loss = self.running_loss / self.grad_acc_step_count - self.min_loss = min(self.min_loss, avg_loss) - step_time, self.previous_step_time = time.perf_counter() - self.previous_step_time, time.perf_counter() - - self.metrics["train/loss"].update(avg_loss) - self.metrics["train/learning_rate"].update(lr) - self.metrics["train/grad_norm"].update(grad_norm) - self.metrics["train/step_time"].update(step_time) - self.metrics["train/tokens_per_second_per_gpu"].update(self.num_tokens / step_time) - self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(self.num_unpadded_tokens / step_time) - self.metrics["train/total_unpadded_tokens_per_batch"].update(self.num_unpadded_tokens / self.logging_frequency) - - if self._profiler is not None: - self._profiler.step() - - if self.fp8_stats_enabled: - debug_api.step() - - if step % self.logging_frequency == 0 and step > 0: - memory_allocated = torch.cuda.memory_allocated() / (1024**3) - self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) - self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) - - metrics = self.metrics.compute() - self.metrics.reset() - metrics["train/global_step"] = torch.tensor(step, dtype=torch.int64) - - if self._dist_config.is_main_process(): - wandb.log(metrics, step=step) - self._progress_bar.update(self.logging_frequency) - self._progress_bar.set_postfix({"loss": avg_loss}) - - if self._dist_config.local_rank == 0: - logger.info(", ".join([f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()])) - - # Reset gradient accumulation tracking for next step - self.num_tokens = 0 - self.num_unpadded_tokens = 0 - self.running_loss = 0.0 - self.grad_acc_step_count = 0 + if self._dist_config.local_rank == 0: + logger.debug("log_step %s", step) + + with torch.no_grad(): + # Use accumulated metrics from gradient accumulation + assert self.grad_acc_step_count > 0, ( + f"Gradient accumulation steps ({self.grad_acc_step_count}) must be greater than 0, " + 
f"and can be incremented by log_micro_step()." + ) + + if isinstance(grad_norm, DTensor): + grad_norm = grad_norm.to_local() + + now = time.perf_counter() + step_time = now - self.previous_step_time + self.previous_step_time = now + + if self._profiler is not None: + self._profiler.step(step) + + if self.fp8_stats_enabled: + debug_api.step() + + if step % self.logging_frequency == 0 and step > 0: + # Calculate average loss over all micro steps in the logging window + avg_loss = self.running_loss / self.grad_acc_step_count + self.min_loss = torch.minimum(self.min_loss, avg_loss) + + # For some reason, these trigger a CudaStreamSynchronize call, which blocks the dataloader in the next + # step. We therefore only update these once every logging_frequency steps. + self.metrics["train/loss"].update(avg_loss) + self.metrics["train/learning_rate"].update(lr) + self.metrics["train/grad_norm"].update(grad_norm) + self.metrics["train/step_time"].update(step_time) + self.metrics["train/tokens_per_second_per_gpu"].update(self.num_tokens / step_time) + self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(self.num_unpadded_tokens / step_time) + self.metrics["train/total_unpadded_tokens_per_batch"].update(self.num_unpadded_tokens) + + memory_allocated = torch.cuda.memory_allocated() / (1024**3) + self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) + self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) + + metrics = self.metrics.compute() + self.metrics.reset() + metrics["train/global_step"] = torch.tensor(step, dtype=torch.int64) + + if self._dist_config.is_main_process(): + wandb.log(metrics, step=step) + self._progress_bar.update(self.logging_frequency) + self._progress_bar.set_postfix({"loss": avg_loss.item()}) + + if self._dist_config.local_rank == 0: + logger.info(", ".join([f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()])) + + # Reset running loss and other tracking variables for next window + 
self.running_loss.zero_() + self.num_tokens = 0 + self.num_unpadded_tokens.zero_() + self.grad_acc_step_count = 0 def finish(self): """Finish the logger and close the progress bar.""" - if self._profiler is not None: - self._profiler.__exit__(None, None, None) - if not self._dist_config.is_main_process(): return @@ -181,39 +206,99 @@ def finish(self): debug_api.end_debug() -def setup_profiler(args: DictConfig, wandb_run: wandb.Run): - """Setup a basic torch profiler for the experiment. +class NsightProfiler: + """Nsight Systems profiler wrapper for performance analysis. + + This profiler uses NVIDIA Nsight Systems to capture detailed performance traces + including CUDA kernels, CPU activities, and memory operations. The profiler + uploads results to wandb as artifacts. Args: - args: The arguments. - wandb_run: The wandb run. + enabled: Whether profiling is enabled. + start_step: The step number at which to start profiling. + end_step: The step number at which to end profiling. + wandb_run: The wandb run for logging artifacts. + dist_config: The distributed configuration. - Returns: - The profiler. + Attributes: + start_step: The step number at which to start profiling. + end_step: The step number at which to end profiling. + current_step: Current step counter. + profiling_started: Whether profiling has been started. + profiling_finished: Whether profiling has been finished. 
""" - _trace_dir = Path(HydraConfig.get().runtime.output_dir) / "traces" - _trace_dir.mkdir(parents=True, exist_ok=True) - - def on_trace_ready(prof): - """Custom callback to save chrome trace, export memory timeline, and log to wandb.""" - # Save chrome trace using tensorboard_trace_handler - tensorboard_trace_handler(str(_trace_dir))(prof) - # Export memory timeline - prof.export_memory_timeline(str(_trace_dir / "memory_timeline.html"), device="cuda:0") - # Log artifacts to wandb - profile_art = wandb.Artifact(name=f"{wandb_run.name}_profile", type="profile") - for file in _trace_dir.glob("*.json"): - profile_art.add_file(str(file), name=file.name) - profile_art.add_file(str(_trace_dir / "memory_timeline.html"), name="memory_timeline.html") - wandb_run.log_artifact(profile_art) - - return profile( - activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], - schedule=schedule(**args.profiler.schedule), - on_trace_ready=on_trace_ready, - with_stack=True, - with_flops=True, - with_modules=True, - profile_memory=True, - record_shapes=True, - ) + + def __init__( + self, + enabled: bool, + start_step: int, + end_step: int, + wandb_run: wandb.Run, + dist_config: DistributedConfig, + ): + """Initialize the Nsight profiler.""" + self._wandb_run = wandb_run + self._dist_config = dist_config + + self.start_step = start_step + self.end_step = end_step + + self.current_step = 0 + self.profiling_started = False + self.profiling_finished = False + + # Check if running under nsys + self.running_under_nsys = "NSYS_PROFILING_SESSION_ID" in os.environ + + if self.running_under_nsys: + logger.info("Detected running under nsys - will use CUDA Profiler API for range control") + else: + logger.warning( + "Not running under nsys. Profiling will be skipped. 
" + "To enable profiling, run your script with: " + "nsys profile -o output_trace --trace=cuda,nvtx,osrt,cudnn,cublas --capture-range=cudaProfilerApi " + "--capture-range-end=stop python train_fsdp2.py profiler.enabled=true" + ) + + def step(self, step_num: int): + """Record a training step and control profiling based on the schedule. + + Args: + step_num: The current training step number. + """ + if not self.running_under_nsys or self.profiling_finished: + return + + self.current_step = step_num + + # Start profiling at start_step + if self.current_step == self.start_step and not self.profiling_started: + self._start_profiling() + # Stop profiling at end_step + elif self.current_step == self.end_step and self.profiling_started: + self._stop_profiling() + + def _start_profiling(self): + """Start CUDA profiling using the CUDA Profiler API.""" + if self.profiling_started: + return + + logger.info(f"Starting Nsight profiling at step {self.current_step}") + try: + torch.cuda.cudart().cudaProfilerStart() # type: ignore[attr-defined] + self.profiling_started = True + except Exception as e: + logger.error(f"Failed to start CUDA profiler: {e}") + + def _stop_profiling(self): + """Stop CUDA profiling using the CUDA Profiler API.""" + if not self.profiling_started or self.profiling_finished: + return + + logger.info(f"Stopping Nsight profiling at step {self.current_step}") + try: + torch.cuda.cudart().cudaProfilerStop() # type: ignore[attr-defined] + self.profiling_started = False + self.profiling_finished = True + except Exception as e: + logger.error(f"Failed to stop CUDA profiler: {e}") diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py index 414fa0418d..0cbaa64673 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_dataset.py @@ -738,7 +738,6 @@ def test_cp_dataloader(tokenizer_path): "input_ids", 
"cu_seq_lens_q", "cu_seq_lens_k", - "attention_mask", "labels", "cu_seq_lens_q_padded", "cu_seq_lens_k_padded", @@ -762,6 +761,7 @@ def test_cp_dataloader_multi_gpu(recipe_path, dataset_path): cmd = [ "torchrun", + "--standalone", "--nproc_per_node=2", "tests/test_dataset.py", "--dataset_path", @@ -854,4 +854,5 @@ def test_cp_dataloader_multi_gpu(recipe_path, dataset_path): assert batch["labels"] is None assert batch["shift_labels"].shape[1] == actual_shape + dataloader.close() torch.distributed.destroy_process_group() diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py index 1a48c4134f..098223291d 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py @@ -202,6 +202,7 @@ def test_checkpoint_save_and_load_two_processes_ddp(recipe_path, tmp_path): # Phase 1: Train for 10 steps with 2 processes cmd_phase1 = [ "torchrun", + "--standalone", "--nproc_per_node=2", str(train_script), f"checkpoint.ckpt_dir={temp_dir}", @@ -265,6 +266,7 @@ def test_checkpoint_save_and_load_two_processes_ddp(recipe_path, tmp_path): # Phase 2: Resume training with 2 processes cmd_phase2 = [ "torchrun", + "--standalone", "--nproc_per_node=2", str(train_script), f"checkpoint.ckpt_dir={temp_dir}", @@ -461,6 +463,7 @@ def test_checkpoint_save_and_load_two_processes_fsdp2(recipe_path, tmp_path): # Phase 1: Train for 10 steps with 2 processes cmd_phase1 = [ "torchrun", + "--standalone", "--nproc_per_node=2", str(train_script), f"checkpoint.ckpt_dir={temp_dir}", @@ -506,6 +509,7 @@ def test_checkpoint_save_and_load_two_processes_fsdp2(recipe_path, tmp_path): # Phase 2: Resume training with 2 processes cmd_phase2 = [ "torchrun", + "--standalone", "--nproc_per_node=2", str(train_script), f"checkpoint.ckpt_dir={temp_dir}", @@ -679,6 +683,7 @@ def 
test_checkpoint_save_and_load_two_processes_fsdp2_with_context_parallelism(r # Phase 1: Train for 10 steps with 2 processes cmd_phase1 = [ "torchrun", + "--standalone", "--nproc_per_node=2", str(train_script), f"checkpoint.ckpt_dir={temp_dir}", @@ -726,6 +731,7 @@ def test_checkpoint_save_and_load_two_processes_fsdp2_with_context_parallelism(r # Phase 2: Resume training with 2 processes cmd_phase2 = [ "torchrun", + "--standalone", "--nproc_per_node=2", str(train_script), f"checkpoint.ckpt_dir={temp_dir}", @@ -950,6 +956,7 @@ def test_scheduler_resume_two_gpu(recipe_path, tmp_path): # Phase 1: Train for 10 steps with 2 GPUs cmd_phase1 = [ "torchrun", + "--standalone", "--nproc_per_node=2", str(train_script), f"checkpoint.ckpt_dir={temp_dir}", @@ -974,6 +981,7 @@ def test_scheduler_resume_two_gpu(recipe_path, tmp_path): # Phase 2: Resume training with 2 GPUs cmd_phase2 = [ "torchrun", + "--standalone", "--nproc_per_node=2", str(train_script), f"checkpoint.ckpt_dir={temp_dir}", diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py new file mode 100644 index 0000000000..19af95ccb1 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py @@ -0,0 +1,210 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for PerfLogger loss calculation correctness.""" + +from unittest import mock + +import pytest +import torch +from omegaconf import OmegaConf +from transformers.modeling_outputs import CausalLMOutputWithPast + +from distributed_config import DistributedConfig +from perf_logger import PerfLogger + + +def _make_args(logging_frequency=1, num_train_steps=100): + """Create a minimal args config for PerfLogger.""" + return OmegaConf.create( + { + "logger": {"frequency": logging_frequency}, + "wandb": {"project": "test", "mode": "disabled"}, + "num_train_steps": num_train_steps, + "profiler": {"enabled": False}, + "fp8_stats_config": {"enabled": False}, + } + ) + + +def _make_batch(seq_len=128, device="cuda:0"): + """Create a minimal batch dict.""" + return { + "input_ids": torch.ones(1, seq_len, dtype=torch.long, device=device), + "attention_mask": torch.ones(1, seq_len, dtype=torch.long, device=device), + } + + +def _make_outputs(loss_value, device="cuda:0"): + """Create CausalLMOutputWithPast with a given loss.""" + return CausalLMOutputWithPast(loss=torch.tensor(loss_value, device=device)) + + +@pytest.fixture +def mock_wandb(): + """Mock wandb to prevent actual logging.""" + with mock.patch("perf_logger.wandb") as mocked: + mocked.init.return_value = mock.MagicMock() + yield mocked + + +@pytest.fixture +def mock_tqdm(): + """Mock tqdm to prevent progress bar output.""" + with mock.patch("perf_logger.tqdm") as mocked: + yield mocked + + +def _create_perf_logger(logging_frequency, mock_wandb, mock_tqdm): + """Create a PerfLogger with the given logging_frequency.""" + dist_config = DistributedConfig() + args = _make_args(logging_frequency=logging_frequency) + return PerfLogger(dist_config, args) + + +def _run_steps(perf_logger, losses, grad_acc_steps=1): + """Simulate training steps with given per-optimizer-step losses. + + Args: + perf_logger: The PerfLogger instance. + losses: List of loss values, one per optimizer step. 
With grad_acc_steps>1, + each value is used for all micro steps in that optimizer step. + grad_acc_steps: Number of micro steps per optimizer step. + """ + device = perf_logger._device + for step_idx, loss_val in enumerate(losses): + step = step_idx + 1 + batch = _make_batch(device=device) + outputs = _make_outputs(loss_val, device=device) + for _ in range(grad_acc_steps): + perf_logger.log_micro_step(step, batch, outputs) + perf_logger.log_step(step, torch.tensor(1.0, device=device), 1e-4) + + +def _get_logged_losses(mock_wandb): + """Extract reported loss values from wandb.log calls.""" + return [call[0][0]["train/loss"].item() for call in mock_wandb.log.call_args_list] + + +class TestPerfLoggerLoss: + """Test that PerfLogger computes average loss correctly.""" + + def test_logging_frequency_1_reports_each_loss(self, mock_wandb, mock_tqdm): + """With logging_frequency=1, each step's loss should be reported as-is.""" + perf_logger = _create_perf_logger(1, mock_wandb, mock_tqdm) + losses = [1.0, 2.0, 3.0, 4.0, 5.0] + _run_steps(perf_logger, losses) + + reported = _get_logged_losses(mock_wandb) + assert len(reported) == len(losses) + for i, (got, expected) in enumerate(zip(reported, losses)): + assert got == pytest.approx(expected), f"Step {i + 1}: expected {expected}, got {got}" + + def test_logging_frequency_5_matches_averaged_frequency_1(self, mock_wandb, mock_tqdm): + """logging_frequency=5 should report the same average as manually averaging 5 frequency-1 losses.""" + losses = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + logging_freq = 5 + + # Run with logging_frequency=1 + perf_logger_1 = _create_perf_logger(1, mock_wandb, mock_tqdm) + _run_steps(perf_logger_1, losses) + freq1_losses = _get_logged_losses(mock_wandb) + assert len(freq1_losses) == 10 + + # Compute expected averages over windows of size logging_freq + expected = [] + for i in range(0, len(freq1_losses), logging_freq): + window = freq1_losses[i : i + logging_freq] + 
expected.append(sum(window) / len(window)) + + # Run with logging_frequency=5 + mock_wandb.log.reset_mock() + perf_logger_5 = _create_perf_logger(logging_freq, mock_wandb, mock_tqdm) + _run_steps(perf_logger_5, losses) + freq5_losses = _get_logged_losses(mock_wandb) + + assert len(freq5_losses) == len(expected), f"Expected {len(expected)} log events, got {len(freq5_losses)}" + for i, (got, exp) in enumerate(zip(freq5_losses, expected)): + assert got == pytest.approx(exp), f"Window {i}: expected {exp}, got {got}" + + def test_logging_frequency_with_grad_accumulation(self, mock_wandb, mock_tqdm): + """Loss should be correct when combining gradient accumulation with logging_frequency > 1.""" + grad_acc_steps = 4 + logging_freq = 3 + # Each value is used for all micro steps in that optimizer step + losses = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + + # Run with logging_frequency=1 to get per-step losses + perf_logger_1 = _create_perf_logger(1, mock_wandb, mock_tqdm) + _run_steps(perf_logger_1, losses, grad_acc_steps=grad_acc_steps) + freq1_losses = _get_logged_losses(mock_wandb) + assert len(freq1_losses) == 6 + + # Each step's loss should equal the input loss (all micro steps have same value) + for i, (got, expected) in enumerate(zip(freq1_losses, losses)): + assert got == pytest.approx(expected), f"Step {i + 1}: expected {expected}, got {got}" + + # Compute expected averages + expected = [] + for i in range(0, len(freq1_losses), logging_freq): + window = freq1_losses[i : i + logging_freq] + expected.append(sum(window) / len(window)) + + # Run with logging_frequency=logging_freq + mock_wandb.log.reset_mock() + perf_logger_n = _create_perf_logger(logging_freq, mock_wandb, mock_tqdm) + _run_steps(perf_logger_n, losses, grad_acc_steps=grad_acc_steps) + freqn_losses = _get_logged_losses(mock_wandb) + + assert len(freqn_losses) == len(expected) + for i, (got, exp) in enumerate(zip(freqn_losses, expected)): + assert got == pytest.approx(exp), f"Window {i}: expected {exp}, got {got}" 
+ + def test_logging_frequency_with_varying_micro_losses(self, mock_wandb, mock_tqdm): + """Test with different loss values across micro steps within a single optimizer step.""" + logging_freq = 2 + device = torch.device("cuda:0") + + perf_logger = _create_perf_logger(logging_freq, mock_wandb, mock_tqdm) + + # Step 1: micro losses [1.0, 3.0] → avg micro loss = 2.0 + for loss_val in [1.0, 3.0]: + batch = _make_batch(device=device) + outputs = _make_outputs(loss_val, device=device) + perf_logger.log_micro_step(1, batch, outputs) + perf_logger.log_step(1, torch.tensor(1.0, device=device), 1e-4) + + # Step 2: micro losses [5.0, 7.0] → avg micro loss = 6.0 + # Window of 2 steps: avg = (2.0 + 6.0) / 2 = 4.0 + for loss_val in [5.0, 7.0]: + batch = _make_batch(device=device) + outputs = _make_outputs(loss_val, device=device) + perf_logger.log_micro_step(2, batch, outputs) + perf_logger.log_step(2, torch.tensor(1.0, device=device), 1e-4) + + reported = _get_logged_losses(mock_wandb) + assert len(reported) == 1 + # Total running_loss = 1.0 + 3.0 + 5.0 + 7.0 = 16.0 + # grad_acc_step_count = 4 (2 micro steps * 2 optimizer steps) + # avg = 16.0 / 4 = 4.0 + assert reported[0] == pytest.approx(4.0), f"Expected 4.0, got {reported[0]}" + + def test_min_loss_tracked_correctly(self, mock_wandb, mock_tqdm): + """min_loss should track the true minimum average loss across windows.""" + perf_logger = _create_perf_logger(1, mock_wandb, mock_tqdm) + losses = [5.0, 2.0, 8.0, 1.0, 4.0] + _run_steps(perf_logger, losses) + + assert perf_logger.min_loss.item() == pytest.approx(1.0) diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py index dae01e7603..0fb725092a 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py @@ -15,6 +15,7 @@ import gc import random +import subprocess import pytest import torch @@ -504,3 +505,75 @@ def 
test_sanity_fsdp2_fp8_stats_logging(tmp_path, recipe_path): assert fp8_log_dir.exists() assert (fp8_log_dir / "rank_0" / "nvdlfw_inspect_logs" / "nvdlfw_inspect_globalrank-0.log").exists() assert (fp8_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs" / "nvdlfw_inspect_globalrank-0.log").exists() + + +def run_train_cmd(cmd, recipe_path): + """Run a training command and check for errors. + + Args: + cmd: List of command arguments to run + recipe_path: Path to the recipe directory (working directory for command) + + Raises: + pytest.fail: If command returns non-zero exit code + """ + result = subprocess.run( + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=240, # 4 minutes timeout + cwd=str(recipe_path), + ) + + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"Command:\n{' '.join(cmd)}\nfailed with exit code {result.returncode}") + + +nsys_available = subprocess.run(["which", "nsys"], check=False, capture_output=True).returncode == 0 + + +@pytest.mark.skipif(not nsys_available, reason="nsys not available in environment") +def test_nsight_profiler_trace_generation(tmp_path, recipe_path): + """Test that Nsight profiler is configured correctly and generates trace metadata. 
+ + This test validates: + - The profiler can be enabled through configuration + - The profiler runs without errors during training + - Training under nsys produces .nsys-rep trace files + - The profiler correctly detects whether it's running under nsys + """ + nsys_output_path = tmp_path / "nsys_profile" + + run_train_cmd( + [ + "nsys", + "profile", + "-o", + str(nsys_output_path), + "--trace=cuda,nvtx", + "--pytorch=autograd-nvtx", + "--python-sampling=true", + "--capture-range=cudaProfilerApi", + "--capture-range-end=stop", + "torchrun", + "--standalone", + "--nproc_per_node=1", + "train_ddp.py", + "--config-name", + "L0_sanity", + "num_train_steps=4", + "profiler.enabled=true", + "profiler.start_step=1", + "profiler.end_step=3", + f"checkpoint.ckpt_dir={tmp_path}", + ], + recipe_path, + ) + + # Verify nsys trace file was created + nsys_files = list(tmp_path.glob("nsys_profile*.nsys-rep")) + assert len(nsys_files) > 0, f"No .nsys-rep files found in {tmp_path}" diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py index f87915af06..d64e6e4eed 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py @@ -74,7 +74,7 @@ def run_train_cmd(cmd, recipe_path): @requires_multi_gpu -def test_multi_gpu_train_ddp(tmp_path, recipe_path): +def test_multi_gpu_train_ddp(recipe_path): """Test DDP training on 2 GPUs. This test validates: @@ -88,6 +88,7 @@ def test_multi_gpu_train_ddp(tmp_path, recipe_path): run_train_cmd( [ "torchrun", + "--standalone", "--nproc_per_node", "2", # 2 processes = 2 GPUs "--standalone", # Single node mode @@ -101,7 +102,7 @@ def test_multi_gpu_train_ddp(tmp_path, recipe_path): @requires_multi_gpu -def test_multi_gpu_train_fsdp2(tmp_path, recipe_path): +def test_multi_gpu_train_fsdp2(recipe_path): """Test FSDP2 training on 2 GPUs. 
This test validates: @@ -115,6 +116,7 @@ def test_multi_gpu_train_fsdp2(tmp_path, recipe_path): run_train_cmd( [ "torchrun", + "--standalone", "--nproc_per_node", "2", # 2 processes = 2 GPUs "--standalone", # Single node mode @@ -139,6 +141,7 @@ def test_multi_gpu_train_ddp_with_checkpointing(tmp_path, recipe_path): run_train_cmd( [ "torchrun", + "--standalone", "--nproc_per_node", "2", "--standalone", @@ -171,6 +174,7 @@ def test_multi_gpu_train_fsdp2_with_checkpointing(tmp_path, recipe_path): run_train_cmd( [ "torchrun", + "--standalone", "--nproc_per_node", "2", "--standalone", @@ -196,6 +200,7 @@ def test_multi_gpu_train_te_fsdp2_cp_bshd(tmp_path, recipe_path): run_train_cmd( [ "torchrun", + "--standalone", "--nproc_per_node=2", "--standalone", "train_fsdp2_cp.py", @@ -219,6 +224,7 @@ def test_multi_gpu_train_te_fsdp2_cp_thd(tmp_path, recipe_path): run_train_cmd( [ "torchrun", + "--standalone", "--nproc_per_node=2", "--standalone", "train_fsdp2_cp.py", @@ -234,3 +240,50 @@ def test_multi_gpu_train_te_fsdp2_cp_thd(tmp_path, recipe_path): ], recipe_path, ) + + +nsys_available = subprocess.run(["which", "nsys"], check=False, capture_output=True).returncode == 0 + + +@pytest.mark.skipif(not nsys_available, reason="nsys not available in environment") +@requires_multi_gpu +def test_nsight_profiler_trace_generation_two_gpu(tmp_path, recipe_path): + """Test that Nsight profiler is configured correctly and generates trace metadata. 
+ + This test validates: + - The profiler can be enabled through configuration + - The profiler runs without errors during training + - Training under nsys produces .nsys-rep trace files + - The profiler correctly detects whether it's running under nsys + """ + nsys_output_path = tmp_path / "nsys_profile" + + run_train_cmd( + [ + "nsys", + "profile", + "-o", + str(nsys_output_path), + "--trace=cuda,nvtx", + "--pytorch=autograd-nvtx", + "--python-sampling=true", + "--capture-range=cudaProfilerApi", + "--capture-range-end=stop", + "torchrun", + "--standalone", + "--nproc_per_node=2", + "train_ddp.py", + "--config-name", + "L0_sanity", + "num_train_steps=4", + "profiler.enabled=true", + "profiler.start_step=1", + "profiler.end_step=3", + f"checkpoint.ckpt_dir={tmp_path}", + ], + recipe_path, + ) + + # Verify nsys trace file was created + nsys_files = list(tmp_path.glob("nsys_profile*.nsys-rep")) + assert len(nsys_files) > 0, f"No .nsys-rep files found in {tmp_path}" diff --git a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py index 72f6adf109..49d22901d3 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py @@ -150,7 +150,7 @@ def main(args: DictConfig) -> float | None: loss.backward() # Log microbatch step data for accumulation metrics - perf_logger.log_micro_step(batch=batch, outputs=outputs) + perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs) # Gradient accumulation - only step optimizer after accumulating gradients if micro_step % args.grad_acc_steps == 0: diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py index 35e894a787..be3e80a514 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py @@ -174,14 +174,14 @@ def main(args: DictConfig) -> float | None: 
loss.backward() # Log microbatch step data for accumulation metrics - perf_logger.log_micro_step(batch=batch, outputs=outputs) + perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs) # Gradient accumulation - only step optimizer after accumulating gradients if micro_step % args.grad_acc_steps == 0: micro_step = 0 # Compute and clip gradient norms. - total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() + total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # Step optimizer. optimizer.step() diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py index 0a70bee2a6..742e21a63d 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py @@ -19,6 +19,7 @@ from pathlib import Path import hydra +import nvtx import torch import transformer_engine.pytorch from omegaconf import DictConfig, OmegaConf @@ -114,7 +115,9 @@ def main(args: DictConfig) -> float | None: # Create the context-aware dataloader. We only create the dataloader on rank 0 and wrap it in a # ContextParallelDataLoaderWrapper that will shard and distribute the data across the context parallelism group. - args.dataset.setdefault("pad_sequences_to_be_divisible_by", device_mesh["cp"].size() * 2) + if args.dataset.get("pad_sequences_to_be_divisible_by", None) is None: + logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2") + OmegaConf.update(args, "dataset.pad_sequences_to_be_divisible_by", device_mesh["cp"].size() * 2) if device_mesh["cp"].get_local_rank() == 0: if args.use_sequence_packing: train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset) @@ -169,22 +172,25 @@ def main(args: DictConfig) -> float | None: micro_step += 1 # Forward pass with mixed precision. 
- with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): - outputs = model(**batch) + with nvtx.annotate("Forward pass", color="green"): + with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): + outputs = model(**batch) # Backward pass - scale loss by grad_acc_steps for proper gradient averaging loss = outputs.loss / args.grad_acc_steps - loss.backward() + + with nvtx.annotate("Backward pass", color="red"): + loss.backward() # Log microbatch step data for accumulation metrics - perf_logger.log_micro_step(batch=batch, outputs=outputs) + perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs) # Gradient accumulation - only step optimizer after accumulating gradients if micro_step % args.grad_acc_steps == 0: micro_step = 0 # Compute and clip gradient norms. - total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() + total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # Step optimizer. optimizer.step() @@ -212,9 +218,9 @@ def main(args: DictConfig) -> float | None: async_save=args.checkpoint.async_save, ) - step += 1 - if step >= args.num_train_steps: - break + step += 1 + if step >= args.num_train_steps: + break # Dataloader exhausted, incrementing epoch epoch += 1 From 3562e75c1956cee214935706edf0ed95c9cbe9f6 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Wed, 11 Feb 2026 09:18:05 -0800 Subject: [PATCH 2/9] revert dockerfile to public image Signed-off-by: Peter St. 
John --- bionemo-recipes/recipes/llama3_native_te/Dockerfile | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/bionemo-recipes/recipes/llama3_native_te/Dockerfile b/bionemo-recipes/recipes/llama3_native_te/Dockerfile index da6d873446..03eadca61e 100644 --- a/bionemo-recipes/recipes/llama3_native_te/Dockerfile +++ b/bionemo-recipes/recipes/llama3_native_te/Dockerfile @@ -1,9 +1,5 @@ # syntax=docker/dockerfile:1.4 -# FROM nvcr.io/nvidia/pytorch:26.01-py3 - -# Note: there's currently a bug in GQA + CP on Hopper but it's fixed in cuDNN 9.18. This image is NVIDIA-internal, but -# leaving this as a note in to remove once the bug is fixed in an nvidia/pytorch image. -FROM gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-pytorch-py3-base +FROM nvcr.io/nvidia/pytorch:26.01-py3 RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=requirements.txt,target=/requirements.txt \ From 1ded8ea0206b7b093333951a98405edb48cd9c10 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Wed, 11 Feb 2026 09:40:40 -0800 Subject: [PATCH 3/9] addressing coderabbit review Signed-off-by: Peter St. 
John --- .../models/esm2/src/esm/collator.py | 19 ++++++++++++++----- .../tests/test_collator_context_parallel.py | 8 ++++++-- bionemo-recipes/models/llama3/collator.py | 19 ++++++++++++++----- .../recipes/esm2_native_te/collator.py | 19 ++++++++++++++----- .../recipes/esm2_native_te/perf_logger.py | 6 +++++- .../recipes/llama3_native_te/collator.py | 19 ++++++++++++++----- .../recipes/llama3_native_te/perf_logger.py | 15 ++++++++++----- 7 files changed, 77 insertions(+), 28 deletions(-) diff --git a/bionemo-recipes/models/esm2/src/esm/collator.py b/bionemo-recipes/models/esm2/src/esm/collator.py index e3cf6a9607..41ec600511 100644 --- a/bionemo-recipes/models/esm2/src/esm/collator.py +++ b/bionemo-recipes/models/esm2/src/esm/collator.py @@ -295,8 +295,12 @@ def __iter__(self): else: # Calculate how many padded tokens are already in the batch tokens_in_batch = current_length - self._padded_len(len(sample["input_ids"])) - # Calculate how many tokens we can fit from this sample + # Calculate how many tokens we can fit from this sample, ensuring the + # padded length doesn't exceed the remaining capacity. 
tokens_available = self.max_tokens_per_batch - tokens_in_batch + if self.pad_sequences_to_be_divisible_by is not None: + d = self.pad_sequences_to_be_divisible_by + tokens_available = (tokens_available // d) * d first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available) yield [*samples, first_part] samples = [remaining_part] @@ -488,11 +492,15 @@ def __iter__(self): @nvtx.annotate("ContextParallelDataLoaderWrapper __next__", color="blue") def __next__(self): """Get the batch from the dataloader for the current CP rank.""" - self._prefetch_thread.join() + if self._prefetch_thread is not None: + self._prefetch_thread.join() result = self._prefetch_result if isinstance(result, StopIteration): self._prefetch_thread = None raise result + if isinstance(result, Exception): + self._prefetch_thread = None + raise result self._kick_prefetch() return result @@ -507,9 +515,10 @@ def _do_one_prefetch(self): torch.cuda.set_device(self._cuda_device) try: self._prefetch_result = self._send_data_to_cp_tp_ranks() - except Exception: - # Process group may have been destroyed; signal stop. - self._prefetch_result = StopIteration() + except StopIteration as e: + self._prefetch_result = e + except Exception as e: + self._prefetch_result = e def close(self): """Stop the prefetch thread. 
Must be called before destroy_process_group().""" diff --git a/bionemo-recipes/models/esm2/tests/test_collator_context_parallel.py b/bionemo-recipes/models/esm2/tests/test_collator_context_parallel.py index fe8991b42d..75314b7d08 100644 --- a/bionemo-recipes/models/esm2/tests/test_collator_context_parallel.py +++ b/bionemo-recipes/models/esm2/tests/test_collator_context_parallel.py @@ -1068,7 +1068,9 @@ def test_data_collator_for_context_parallel_thd_correctness(tokenizer): ) # Create the context parallel collator - cp_collator = DataCollatorForContextParallel(collator=base_collator, cp_world_size=cp_world_size, qkv_format="thd") + cp_collator = DataCollatorForContextParallel( + collator=base_collator, device_mesh=_DummyCollatorMesh(cp_size=cp_world_size), qkv_format="thd" + ) # Create test sequences - 8 tokens each for easy division features = [ @@ -1148,7 +1150,9 @@ def test_data_collator_for_context_parallel_thd_max_length_rounding(tokenizer, m collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False), pad_sequences_to_be_divisible_by=divisibility_factor, ) - cp_collator = DataCollatorForContextParallel(collator=base_collator, cp_world_size=cp_world_size, qkv_format="thd") + cp_collator = DataCollatorForContextParallel( + collator=base_collator, device_mesh=_DummyCollatorMesh(cp_size=cp_world_size), qkv_format="thd" + ) # Use a single sequence to ensure max_seqlen is exactly what we expect after padding features = [{"input_ids": input_ids}] diff --git a/bionemo-recipes/models/llama3/collator.py b/bionemo-recipes/models/llama3/collator.py index e3cf6a9607..41ec600511 100644 --- a/bionemo-recipes/models/llama3/collator.py +++ b/bionemo-recipes/models/llama3/collator.py @@ -295,8 +295,12 @@ def __iter__(self): else: # Calculate how many padded tokens are already in the batch tokens_in_batch = current_length - self._padded_len(len(sample["input_ids"])) - # Calculate how many tokens we can fit from this sample + # Calculate how many tokens we can fit 
from this sample, ensuring the + # padded length doesn't exceed the remaining capacity. tokens_available = self.max_tokens_per_batch - tokens_in_batch + if self.pad_sequences_to_be_divisible_by is not None: + d = self.pad_sequences_to_be_divisible_by + tokens_available = (tokens_available // d) * d first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available) yield [*samples, first_part] samples = [remaining_part] @@ -488,11 +492,15 @@ def __iter__(self): @nvtx.annotate("ContextParallelDataLoaderWrapper __next__", color="blue") def __next__(self): """Get the batch from the dataloader for the current CP rank.""" - self._prefetch_thread.join() + if self._prefetch_thread is not None: + self._prefetch_thread.join() result = self._prefetch_result if isinstance(result, StopIteration): self._prefetch_thread = None raise result + if isinstance(result, Exception): + self._prefetch_thread = None + raise result self._kick_prefetch() return result @@ -507,9 +515,10 @@ def _do_one_prefetch(self): torch.cuda.set_device(self._cuda_device) try: self._prefetch_result = self._send_data_to_cp_tp_ranks() - except Exception: - # Process group may have been destroyed; signal stop. - self._prefetch_result = StopIteration() + except StopIteration as e: + self._prefetch_result = e + except Exception as e: + self._prefetch_result = e def close(self): """Stop the prefetch thread. 
Must be called before destroy_process_group().""" diff --git a/bionemo-recipes/recipes/esm2_native_te/collator.py b/bionemo-recipes/recipes/esm2_native_te/collator.py index e3cf6a9607..41ec600511 100644 --- a/bionemo-recipes/recipes/esm2_native_te/collator.py +++ b/bionemo-recipes/recipes/esm2_native_te/collator.py @@ -295,8 +295,12 @@ def __iter__(self): else: # Calculate how many padded tokens are already in the batch tokens_in_batch = current_length - self._padded_len(len(sample["input_ids"])) - # Calculate how many tokens we can fit from this sample + # Calculate how many tokens we can fit from this sample, ensuring the + # padded length doesn't exceed the remaining capacity. tokens_available = self.max_tokens_per_batch - tokens_in_batch + if self.pad_sequences_to_be_divisible_by is not None: + d = self.pad_sequences_to_be_divisible_by + tokens_available = (tokens_available // d) * d first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available) yield [*samples, first_part] samples = [remaining_part] @@ -488,11 +492,15 @@ def __iter__(self): @nvtx.annotate("ContextParallelDataLoaderWrapper __next__", color="blue") def __next__(self): """Get the batch from the dataloader for the current CP rank.""" - self._prefetch_thread.join() + if self._prefetch_thread is not None: + self._prefetch_thread.join() result = self._prefetch_result if isinstance(result, StopIteration): self._prefetch_thread = None raise result + if isinstance(result, Exception): + self._prefetch_thread = None + raise result self._kick_prefetch() return result @@ -507,9 +515,10 @@ def _do_one_prefetch(self): torch.cuda.set_device(self._cuda_device) try: self._prefetch_result = self._send_data_to_cp_tp_ranks() - except Exception: - # Process group may have been destroyed; signal stop. - self._prefetch_result = StopIteration() + except StopIteration as e: + self._prefetch_result = e + except Exception as e: + self._prefetch_result = e def close(self): """Stop the prefetch thread. 
Must be called before destroy_process_group().""" diff --git a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py index c19e24e147..30487fb309 100644 --- a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py @@ -127,7 +127,11 @@ def log_step( metrics = self.metrics.compute() self.metrics.reset() - metrics["train/global_step"] = torch.tensor(step, dtype=torch.int64) + metrics = { + k: v.detach().cpu().item() if isinstance(v, torch.Tensor) and v.dim() == 0 else v + for k, v in metrics.items() + } + metrics["train/global_step"] = step if self._dist_config.is_main_process(): wandb.log(metrics, step=step) diff --git a/bionemo-recipes/recipes/llama3_native_te/collator.py b/bionemo-recipes/recipes/llama3_native_te/collator.py index e3cf6a9607..41ec600511 100644 --- a/bionemo-recipes/recipes/llama3_native_te/collator.py +++ b/bionemo-recipes/recipes/llama3_native_te/collator.py @@ -295,8 +295,12 @@ def __iter__(self): else: # Calculate how many padded tokens are already in the batch tokens_in_batch = current_length - self._padded_len(len(sample["input_ids"])) - # Calculate how many tokens we can fit from this sample + # Calculate how many tokens we can fit from this sample, ensuring the + # padded length doesn't exceed the remaining capacity. 
tokens_available = self.max_tokens_per_batch - tokens_in_batch + if self.pad_sequences_to_be_divisible_by is not None: + d = self.pad_sequences_to_be_divisible_by + tokens_available = (tokens_available // d) * d first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available) yield [*samples, first_part] samples = [remaining_part] @@ -488,11 +492,15 @@ def __iter__(self): @nvtx.annotate("ContextParallelDataLoaderWrapper __next__", color="blue") def __next__(self): """Get the batch from the dataloader for the current CP rank.""" - self._prefetch_thread.join() + if self._prefetch_thread is not None: + self._prefetch_thread.join() result = self._prefetch_result if isinstance(result, StopIteration): self._prefetch_thread = None raise result + if isinstance(result, Exception): + self._prefetch_thread = None + raise result self._kick_prefetch() return result @@ -507,9 +515,10 @@ def _do_one_prefetch(self): torch.cuda.set_device(self._cuda_device) try: self._prefetch_result = self._send_data_to_cp_tp_ranks() - except Exception: - # Process group may have been destroyed; signal stop. - self._prefetch_result = StopIteration() + except StopIteration as e: + self._prefetch_result = e + except Exception as e: + self._prefetch_result = e def close(self): """Stop the prefetch thread. 
Must be called before destroy_process_group().""" diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py index 57f2e8bdfd..3d6f63f256 100644 --- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py @@ -147,10 +147,6 @@ def log_step( if isinstance(grad_norm, DTensor): grad_norm = grad_norm.to_local() - now = time.perf_counter() - step_time = now - self.previous_step_time - self.previous_step_time = now - if self._profiler is not None: self._profiler.step(step) @@ -162,6 +158,11 @@ def log_step( avg_loss = self.running_loss / self.grad_acc_step_count self.min_loss = torch.minimum(self.min_loss, avg_loss) + # Calculate an average step time over all steps in the logging window + now = time.perf_counter() + step_time = (now - self.previous_step_time) / self.logging_frequency + self.previous_step_time = now + # For some reason, these trigger a CudaStreamSynchronize call, which blocks the dataloader in the next # step. We therefore only update these once every logging_frequency steps. self.metrics["train/loss"].update(avg_loss) @@ -178,7 +179,11 @@ def log_step( metrics = self.metrics.compute() self.metrics.reset() - metrics["train/global_step"] = torch.tensor(step, dtype=torch.int64) + metrics = { + k: v.detach().cpu().item() if isinstance(v, torch.Tensor) and v.dim() == 0 else v + for k, v in metrics.items() + } + metrics["train/global_step"] = step if self._dist_config.is_main_process(): wandb.log(metrics, step=step) From 806c5692768659de2d8d66782748375c32d63937 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Wed, 11 Feb 2026 10:29:00 -0800 Subject: [PATCH 4/9] fix failing tests Signed-off-by: Peter St. 
John --- .../recipes/esm2_native_te/perf_logger.py | 51 +++++++++++-------- .../tests/test_perf_logger.py | 2 +- 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py index 30487fb309..f7a71b3e6e 100644 --- a/bionemo-recipes/recipes/esm2_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/esm2_native_te/perf_logger.py @@ -22,6 +22,7 @@ import torchmetrics.text import wandb from omegaconf import DictConfig, OmegaConf +from torch.distributed.tensor import DTensor from tqdm import tqdm from transformers.modeling_outputs import MaskedLMOutput @@ -83,7 +84,7 @@ def log_step( step: int, batch: dict[str, torch.Tensor], outputs: MaskedLMOutput, - grad_norm: torch.Tensor, + grad_norm: torch.Tensor | DTensor, lr: float, ): """Log a step to the logger and wandb. @@ -96,31 +97,39 @@ def log_step( lr: The learning rate of the step. """ with torch.no_grad(): - num_tokens = batch["input_ids"].numel() - # 1 is the padding token for ESM-2. - num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != 1].numel() - - self.min_loss = torch.minimum(self.min_loss, outputs.loss) - step_time, self.previous_step_time = time.perf_counter() - self.previous_step_time, time.perf_counter() - - self.metrics["train/loss"].update(outputs.loss) - self.metrics["train/learning_rate"].update(lr) - self.metrics["train/grad_norm"].update(grad_norm) - self.metrics["train/step_time"].update(step_time) - self.metrics["train/tokens_per_second_per_gpu"].update(num_tokens / step_time) - self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(num_unpadded_tokens / step_time) - self.metrics["train/total_unpadded_tokens_per_batch"].update(num_unpadded_tokens / self.logging_frequency) - - # Handle sequence packing for torchmetrics calculation. 
- if outputs.logits.dim() < 3: - outputs.logits = outputs.logits.unsqueeze(0) - - self.metrics["train/perplexity"].update(outputs.logits, batch["labels"]) + # FSDP2's clip_grad_norm_ returns a DTensor; convert to local tensor for torchmetrics compatibility. + if isinstance(grad_norm, DTensor): + grad_norm = grad_norm.to_local() if self.fp8_stats_enabled: debug_api.step() if step % self.logging_frequency == 0 and step > 0: + num_tokens = batch["input_ids"].numel() + # 1 is the padding token for ESM-2. + num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != 1].numel() + + self.min_loss = torch.minimum(self.min_loss, outputs.loss) + elapsed_time, self.previous_step_time = ( + time.perf_counter() - self.previous_step_time, + time.perf_counter(), + ) + step_time = elapsed_time / self.logging_frequency + + self.metrics["train/loss"].update(outputs.loss) + self.metrics["train/learning_rate"].update(lr) + self.metrics["train/grad_norm"].update(grad_norm) + self.metrics["train/step_time"].update(step_time) + self.metrics["train/tokens_per_second_per_gpu"].update(num_tokens / step_time) + self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(num_unpadded_tokens / step_time) + self.metrics["train/total_unpadded_tokens_per_batch"].update(num_unpadded_tokens) + + # Handle sequence packing for torchmetrics calculation. 
+ if outputs.logits.dim() < 3: + outputs.logits = outputs.logits.unsqueeze(0) + + self.metrics["train/perplexity"].update(outputs.logits, batch["labels"]) + memory_allocated = torch.cuda.memory_allocated() / (1024**3) self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated) self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated) diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py index 19af95ccb1..370e174c6f 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py @@ -95,7 +95,7 @@ def _run_steps(perf_logger, losses, grad_acc_steps=1): def _get_logged_losses(mock_wandb): """Extract reported loss values from wandb.log calls.""" - return [call[0][0]["train/loss"].item() for call in mock_wandb.log.call_args_list] + return [call[0][0]["train/loss"] for call in mock_wandb.log.call_args_list] class TestPerfLoggerLoss: From e424e406008722d0f167fcef35d4af0dbf631cca Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Wed, 11 Feb 2026 11:32:05 -0800 Subject: [PATCH 5/9] add collator docstring Signed-off-by: Peter St. John --- bionemo-recipes/models/esm2/src/esm/collator.py | 7 ++++++- bionemo-recipes/models/llama3/collator.py | 7 ++++++- bionemo-recipes/recipes/esm2_native_te/collator.py | 7 ++++++- bionemo-recipes/recipes/llama3_native_te/collator.py | 7 ++++++- 4 files changed, 24 insertions(+), 4 deletions(-) diff --git a/bionemo-recipes/models/esm2/src/esm/collator.py b/bionemo-recipes/models/esm2/src/esm/collator.py index 41ec600511..b215a96e8c 100644 --- a/bionemo-recipes/models/esm2/src/esm/collator.py +++ b/bionemo-recipes/models/esm2/src/esm/collator.py @@ -510,7 +510,12 @@ def _kick_prefetch(self): self._prefetch_thread.start() def _do_one_prefetch(self): - """Fetch one batch in the background. 
Stores result in _prefetch_result.""" + """Fetch one batch in the background. + + This function calls the _send_data_to_cp_tp_ranks function to materialize the next batches for all ranks in the + given CP/TP group, and uses torch.distributed.scatter_object_list to scatter these batches to their + corresponding ranks. The result is stored in _prefetch_result, and returned when __next__ is called. + """ if self._cuda_device is not None: torch.cuda.set_device(self._cuda_device) try: diff --git a/bionemo-recipes/models/llama3/collator.py b/bionemo-recipes/models/llama3/collator.py index 41ec600511..b215a96e8c 100644 --- a/bionemo-recipes/models/llama3/collator.py +++ b/bionemo-recipes/models/llama3/collator.py @@ -510,7 +510,12 @@ def _kick_prefetch(self): self._prefetch_thread.start() def _do_one_prefetch(self): - """Fetch one batch in the background. Stores result in _prefetch_result.""" + """Fetch one batch in the background. + + This function calls the _send_data_to_cp_tp_ranks function to materialize the next batches for all ranks in the + given CP/TP group, and uses torch.distributed.scatter_object_list to scatter these batches to their + corresponding ranks. The result is stored in _prefetch_result, and returned when __next__ is called. + """ if self._cuda_device is not None: torch.cuda.set_device(self._cuda_device) try: diff --git a/bionemo-recipes/recipes/esm2_native_te/collator.py b/bionemo-recipes/recipes/esm2_native_te/collator.py index 41ec600511..b215a96e8c 100644 --- a/bionemo-recipes/recipes/esm2_native_te/collator.py +++ b/bionemo-recipes/recipes/esm2_native_te/collator.py @@ -510,7 +510,12 @@ def _kick_prefetch(self): self._prefetch_thread.start() def _do_one_prefetch(self): - """Fetch one batch in the background. Stores result in _prefetch_result.""" + """Fetch one batch in the background. 
+ + This function calls the _send_data_to_cp_tp_ranks function to materialize the next batches for all ranks in the + given CP/TP group, and uses torch.distributed.scatter_object_list to scatter these batches to their + corresponding ranks. The result is stored in _prefetch_result, and returned when __next__ is called. + """ if self._cuda_device is not None: torch.cuda.set_device(self._cuda_device) try: diff --git a/bionemo-recipes/recipes/llama3_native_te/collator.py b/bionemo-recipes/recipes/llama3_native_te/collator.py index 41ec600511..b215a96e8c 100644 --- a/bionemo-recipes/recipes/llama3_native_te/collator.py +++ b/bionemo-recipes/recipes/llama3_native_te/collator.py @@ -510,7 +510,12 @@ def _kick_prefetch(self): self._prefetch_thread.start() def _do_one_prefetch(self): - """Fetch one batch in the background. Stores result in _prefetch_result.""" + """Fetch one batch in the background. + + This function calls the _send_data_to_cp_tp_ranks function to materialize the next batches for all ranks in the + given CP/TP group, and uses torch.distributed.scatter_object_list to scatter these batches to their + corresponding ranks. The result is stored in _prefetch_result, and returned when __next__ is called. + """ if self._cuda_device is not None: torch.cuda.set_device(self._cuda_device) try: From 18db577177795cc3e7794c562020ba41fd588198 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Wed, 11 Feb 2026 09:17:09 -0800 Subject: [PATCH 6/9] starting cp benchmarking Signed-off-by: Peter St. 
John --- .../recipes/llama3_native_te/dataset.py | 73 +++++++++++++++++++ .../hydra_config/L2_cp_benchmark.yaml | 41 +++++++++++ .../hydra_config/defaults.yaml | 1 + .../llama3_native_te/train_fsdp2_cp.py | 8 +- 4 files changed, 121 insertions(+), 2 deletions(-) create mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_cp_benchmark.yaml diff --git a/bionemo-recipes/recipes/llama3_native_te/dataset.py b/bionemo-recipes/recipes/llama3_native_te/dataset.py index 6c0b47cf68..e566fff02d 100644 --- a/bionemo-recipes/recipes/llama3_native_te/dataset.py +++ b/bionemo-recipes/recipes/llama3_native_te/dataset.py @@ -17,6 +17,7 @@ import datasets import datasets.distributed +import torch from torch.utils.data import DataLoader, DistributedSampler from torchdata.stateful_dataloader import StatefulDataLoader from transformers import AutoTokenizer @@ -306,3 +307,75 @@ def create_thd_dataloader( ) return train_dataloader, tokenized_dataset + + +class MockTokenDataset(torch.utils.data.Dataset): + """Dataset that generates random token sequences for benchmarking. + + All sequences have the same fixed length, so no padding is needed. + + Args: + vocab_size: Vocabulary size for random token generation. + seq_length: Length of each generated sequence. + num_samples: Total number of samples in the dataset. 
+ """ + + def __init__(self, vocab_size: int, seq_length: int, num_samples: int): + """Initialize the mock dataset.""" + self.vocab_size = vocab_size + self.seq_length = seq_length + self.num_samples = num_samples + + def __len__(self): + """Return the number of samples.""" + return self.num_samples + + def __getitem__(self, idx): + """Return a random token sequence.""" + input_ids = torch.randint(0, self.vocab_size, (self.seq_length,)) + return {"input_ids": input_ids} + + +def _mock_collator(features: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]: + """Collator for MockTokenDataset that stacks fixed-length sequences into a batch.""" + input_ids = torch.stack([f["input_ids"] for f in features]) + return {"input_ids": input_ids, "labels": input_ids.clone(), "attention_mask": torch.ones_like(input_ids)} + + +def create_mock_dataloader( + distributed_config: DistributedConfig, + micro_batch_size: int, + max_seq_length: int, + vocab_size: int = 128256, + num_samples: int = 100_000, + **kwargs, +): + """Create a mock dataloader with random tokens for benchmarking. + + Args: + distributed_config: The distributed configuration. + micro_batch_size: The batch size per device. + max_seq_length: The sequence length of each generated sample. + vocab_size: Vocabulary size for random token generation. Defaults to Llama 3 vocab size. + num_samples: Total number of samples in the dataset. + **kwargs: Ignored extra arguments for compatibility with other dataloader configs. + + Returns: + A tuple of (dataloader, sampler). 
+ """ + dataset = MockTokenDataset(vocab_size, max_seq_length, num_samples) + sampler = DistributedSampler( + dataset, + rank=distributed_config.rank, + num_replicas=distributed_config.world_size, + seed=42, + ) + train_dataloader = DataLoader( + dataset, + batch_size=micro_batch_size, + sampler=sampler, + collate_fn=_mock_collator, + num_workers=0, + pin_memory=True, + ) + return train_dataloader, sampler diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_cp_benchmark.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_cp_benchmark.yaml new file mode 100644 index 0000000000..addb1d2bcc --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_cp_benchmark.yaml @@ -0,0 +1,41 @@ +defaults: + - defaults + - _self_ + +config_name_or_path: ./model_configs/meta-llama/Llama-3.2-1B + +config_kwargs: + attn_input_format: "bshd" + self_attn_mask_type: "causal" + +cp_size: 1 + +use_mock_dataset: true +use_sequence_packing: false +use_meta_device: true +use_torch_compile: false + +num_train_steps: 100 + +dataset: + tokenizer_name_or_path: null # Not needed for mock dataset + micro_batch_size: 1 + max_seq_length: 8192 + num_samples: 100_000 + load_dataset_kwargs: null # Not needed for mock dataset + +wandb: + name: "llama3-cp-benchmark" + mode: "offline" + +lr_scheduler_kwargs: + num_warmup_steps: 10 + num_decay_steps: 90 + +checkpoint: + ckpt_dir: null + save_final_model: false + resume_from_checkpoint: false + +logger: + frequency: 1 diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml index d6c181598f..784558b0b0 100644 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml @@ -13,6 +13,7 @@ use_meta_device: true use_torch_compile: false use_sequence_packing: false +use_mock_dataset: false dataset: tokenizer_name_or_path: ??? 
# Set to the path of your tokenizer (e.g., meta-llama/Llama-3.1-8B or ./tokenizers/nucleotide_fast_tokenizer) diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py index 742e21a63d..5113561948 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py @@ -30,7 +30,7 @@ from checkpoint import load_checkpoint_fsdp2, save_checkpoint_fsdp2, save_final_model_fsdp2, should_save_checkpoint from collator import ContextParallelDataLoaderWrapper, DataCollatorForContextParallel -from dataset import create_bshd_dataloader, create_thd_dataloader +from dataset import create_bshd_dataloader, create_mock_dataloader, create_thd_dataloader from distributed_config import DistributedConfig from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM from perf_logger import PerfLogger @@ -119,7 +119,11 @@ def main(args: DictConfig) -> float | None: logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2") OmegaConf.update(args, "dataset.pad_sequences_to_be_divisible_by", device_mesh["cp"].size() * 2) if device_mesh["cp"].get_local_rank() == 0: - if args.use_sequence_packing: + if args.use_mock_dataset: + train_dataloader, dataset_or_sampler = create_mock_dataloader( + dist_config, vocab_size=config.vocab_size, **args.dataset + ) + elif args.use_sequence_packing: train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset) else: train_dataloader, dataset_or_sampler = create_bshd_dataloader(dist_config, **args.dataset) From 9fb7b9377ccacd5010d59653d109252b18d7753d Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Wed, 11 Feb 2026 10:37:00 -0800 Subject: [PATCH 7/9] add llama3.1 8b config Signed-off-by: Peter St. 
John --- .../hydra_config/L2_cp_benchmark.yaml | 2 +- .../meta-llama/Llama-3.1-8B/config.json | 34 +++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-8B/config.json diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_cp_benchmark.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_cp_benchmark.yaml index addb1d2bcc..aa0fd373e9 100644 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_cp_benchmark.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_cp_benchmark.yaml @@ -2,7 +2,7 @@ defaults: - defaults - _self_ -config_name_or_path: ./model_configs/meta-llama/Llama-3.2-1B +config_name_or_path: ./model_configs/meta-llama/Llama-3.1-8B config_kwargs: attn_input_format: "bshd" diff --git a/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-8B/config.json b/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-8B/config.json new file mode 100644 index 0000000000..0e235cd7de --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-8B/config.json @@ -0,0 +1,34 @@ +{ + "architectures": [ + "LlamaForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 131072, + "mlp_bias": false, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 0.00001, + "rope_scaling": { + "factor": 8.0, + "low_freq_factor": 1.0, + "high_freq_factor": 4.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3" + }, + "rope_theta": 500000.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.43.0.dev0", + 
"use_cache": true, + "vocab_size": 128256 +} From 84aba4671ea156f4d467dfb1237e91a18c860ed8 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Wed, 11 Feb 2026 10:50:57 -0800 Subject: [PATCH 8/9] testing on b300 Signed-off-by: Peter St. John --- .../hydra_config/L2_cp_benchmark.yaml | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_cp_benchmark.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_cp_benchmark.yaml index aa0fd373e9..4b48297b9c 100644 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_cp_benchmark.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_cp_benchmark.yaml @@ -26,7 +26,8 @@ dataset: wandb: name: "llama3-cp-benchmark" - mode: "offline" + project: "bionemo-recipes-pstjohn" + mode: "online" lr_scheduler_kwargs: num_warmup_steps: 10 @@ -38,4 +39,12 @@ checkpoint: resume_from_checkpoint: false logger: - frequency: 1 + frequency: 10 + +fp8_config: + enabled: false + fp8_recipe: transformer_engine.common.recipe.MXFP8BlockScaling + fp8_format: "HYBRID" + fp8_recipe_kwargs: {} + quantized_model_init_kwargs: + enabled: false From 40890021a4dacd63672e18990de99013e59d4f2e Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Fri, 20 Feb 2026 13:40:43 -0800 Subject: [PATCH 9/9] add train_ddp_cp Signed-off-by: Peter St. 
John --- .../hydra_config/L0_sanity_ddp_cp.yaml | 12 + .../recipes/llama3_native_te/perf_logger.py | 16 +- .../recipes/llama3_native_te/slurm_nvl72.sh | 49 ++++ .../recipes/llama3_native_te/train_ddp_cp.py | 231 ++++++++++++++++++ 4 files changed, 299 insertions(+), 9 deletions(-) create mode 100644 bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_ddp_cp.yaml create mode 100644 bionemo-recipes/recipes/llama3_native_te/slurm_nvl72.sh create mode 100644 bionemo-recipes/recipes/llama3_native_te/train_ddp_cp.py diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_ddp_cp.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_ddp_cp.yaml new file mode 100644 index 0000000000..1f4f52f330 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L0_sanity_ddp_cp.yaml @@ -0,0 +1,12 @@ +defaults: + - L0_sanity + - _self_ + +cp_size: 1 + +use_mock_dataset: true +use_sequence_packing: false + +config_kwargs: + attn_input_format: "bshd" + self_attn_mask_type: "causal" diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py index 3d6f63f256..2a1cbead7c 100644 --- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py @@ -77,12 +77,13 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): self._wandb_run = wandb.init(**args.wandb, config=self._run_config) self._progress_bar = tqdm(total=args.num_train_steps, desc="Training") - if args.profiler.enabled: - self._profiler = NsightProfiler( - **args.profiler, - wandb_run=self._wandb_run, - dist_config=dist_config, - ) + # Create profiler on all ranks so every GPU calls cudaProfilerStart/Stop, + # ensuring nsys captures the correct step range on every node. 
#!/bin/bash
# Slurm launcher for the Llama3 context-parallel benchmark on an NVL72 (GB300) segment.
# One Slurm task per node; torchrun fans out to 4 GPUs per node.
#
# Usage:
#   sbatch slurm_nvl72.sh       # defaults to cp=6
#   sbatch slurm_nvl72.sh 9     # cp_size=9, seq_len=73728 (9*8192)
#
#SBATCH --nodes=9 # number of nodes
#SBATCH --segment=9
#SBATCH --ntasks-per-node=1 # one task per node; torchrun handles per-GPU processes
#SBATCH --time=00:20:00 # wall time
#SBATCH --mem=0 # all mem avail
#SBATCH --account=healthcareeng_bionemo # account
#SBATCH --partition=gb300 # partition
#SBATCH --mail-type=FAIL # only send email on failure
#SBATCH --exclusive # exclusive node access
#SBATCH --output=job_output/slurm_%x.%j.out
#SBATCH --job-name=healthcareeng_bionemo-recipes.llama3-cp-benchmark

set -x -e
ulimit -c 0

# Accept cp size as first positional argument (default: 6). Sequence length scales
# with the CP degree so every CP rank processes an 8192-token shard.
CP_SIZE=${1:-6}
MAX_SEQ_LENGTH=$((CP_SIZE * 8192))
SEQ_LENGTH_K=$((MAX_SEQ_LENGTH / 1000))K

# SECURITY: never hardcode credentials in a committed script. The token must come
# from the submitting shell's environment; fail fast with a clear message otherwise.
: "${HF_TOKEN:?Set HF_TOKEN in the environment before submitting this job}"

export CMD="TRITON_CACHE_DIR=/tmp/triton_cache \
    HF_TOKEN=${HF_TOKEN} \
    NVTE_BATCH_MHA_P2P_COMM=1 \
    torchrun \
    --rdzv_id \$SLURM_JOB_ID \
    --rdzv_backend c10d \
    --rdzv_endpoint \$MASTER_ADDR:\$MASTER_PORT \
    --nproc-per-node 4 \
    --nnodes \$SLURM_NNODES \
    --node-rank \$SLURM_NODEID \
    train_fsdp2_cp.py \
    --config-name L2_cp_benchmark \
    wandb.name=llama3-cp-benchmark-lyris-32gpu-cp${CP_SIZE}-70B-${SEQ_LENGTH_K} \
    wandb.project=bionemo-recipes-pstjohn \
    wandb.mode=online \
    dataset.max_seq_length=${MAX_SEQ_LENGTH} \
    cp_size=${CP_SIZE} \
    config_name_or_path=meta-llama/Llama-3.1-70B
"

srun \
    --container-image=/lustre/fsw/healthcareeng_bionemo/pstjohn/enroot/nvidian+cvai_bnmo_trng+bionemo+llama3_cp_arm_0211.sqsh \
    --container-mounts=$HOME/.netrc:/root/.netrc,/lustre/fsw/healthcareeng_bionemo/pstjohn/cache:/root/.cache \
    bash -c "$CMD"
import gc
import logging
from contextlib import nullcontext
from pathlib import Path

import hydra
import torch
import transformer_engine.pytorch
from omegaconf import DictConfig, OmegaConf
from torch.distributed.device_mesh import init_device_mesh
from torch.optim import AdamW
from transformer_engine.common.recipe import Format

from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint
from collator import ContextParallelDataLoaderWrapper, DataCollatorForContextParallel
from dataset import create_bshd_dataloader, create_mock_dataloader, create_thd_dataloader
from distributed_config import DistributedConfig
from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
from perf_logger import PerfLogger
from scheduler import get_cosine_annealing_schedule_with_warmup


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@hydra.main(config_path="hydra_config", config_name="L0_sanity_ddp_cp", version_base="1.2")
def main(args: DictConfig) -> float | None:
    """Train Llama3 with TE layers using DDP + Context Parallelism for benchmarking.

    Returns:
        float: The loss value for the final batch.
    """
    # Set up the distributed process group. The gloo backend is kept alongside nccl so
    # CPU-side collectives (e.g. checkpoint coordination) have a transport too.
    dist_config = DistributedConfig()
    logger.info("Initializing distributed training: %s", dist_config)
    device = torch.device(f"cuda:{dist_config.local_rank}")
    torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl", device_id=device)
    torch.cuda.set_device(dist_config.local_rank)

    # 2D mesh: outer "dp" dimension for data parallelism, inner "cp" for context parallelism.
    device_mesh = init_device_mesh(
        "cuda",
        mesh_shape=(dist_config.world_size // args.cp_size, args.cp_size),
        mesh_dim_names=("dp", "cp"),
    )
    logger.info(f"Created device mesh: {device_mesh}")

    # Build the FP8 recipe unconditionally; it is only consumed when fp8_config.enabled is true.
    recipe_cls = hydra.utils.get_class(args.fp8_config.fp8_recipe)
    fp8_recipe = recipe_cls(fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs)

    # Instantiate an untrained Llama3 causal-LM from config (no pretrained weights loaded).
    config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)

    with transformer_engine.pytorch.quantized_model_init(
        recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs
    ):
        model = NVLlamaForCausalLM(config)

    logger.info("Initialized Model:\n%s", model)

    model = model.to(device=device)

    # Each TE transformer layer needs the CP group registered before the DDP wrap so its
    # attention kernels exchange KV chunks over the right ranks/stream.
    cp_group = device_mesh["cp"].get_group()
    cp_ranks = torch.distributed.get_process_group_ranks(cp_group)
    for layer in model.model.layers:
        layer.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream())

    # DDP replicates across the "dp" mesh dimension only; CP ranks hold shards of one replica.
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[dist_config.local_rank],
        output_device=dist_config.local_rank,
        device_mesh=device_mesh["dp"],
    )

    optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True))
    scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)

    if args.use_torch_compile:
        model = torch.compile(model)

    # Default the per-sequence padding multiple to 2 * cp_size so every sequence can be split
    # into the 2*cp chunks that TE's load-balanced CP sharding expects.
    if args.dataset.get("pad_sequences_to_be_divisible_by", None) is None:
        logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2")
        OmegaConf.update(args, "dataset.pad_sequences_to_be_divisible_by", device_mesh["cp"].size() * 2)

    # Only CP-rank-0 reads data; the wrapper below scatters each batch across the CP group.
    if device_mesh["cp"].get_local_rank() == 0:
        if args.use_mock_dataset:
            train_dataloader, dataset_or_sampler = create_mock_dataloader(
                dist_config, vocab_size=config.vocab_size, **args.dataset
            )
        elif args.use_sequence_packing:
            train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset)
        else:
            train_dataloader, dataset_or_sampler = create_bshd_dataloader(dist_config, **args.dataset)

        train_dataloader.collate_fn = DataCollatorForContextParallel(
            collator=train_dataloader.collate_fn,
            device_mesh=device_mesh,
            qkv_format=args.config_kwargs.attn_input_format,
            is_causal_lm=True,
        )
    else:
        train_dataloader = None
        dataset_or_sampler = None

    train_dataloader = ContextParallelDataLoaderWrapper(train_dataloader, device_mesh["cp"])

    # Resume from checkpoint when requested; otherwise start fresh at step/epoch 0.
    ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_ddp_cp" if args.checkpoint.ckpt_dir else None
    if args.checkpoint.resume_from_checkpoint and ckpt_path:
        model, optimizer, scheduler, train_dataloader, start_step, epoch = load_checkpoint_ddp(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            ckpt_path=ckpt_path,
            dist_config=dist_config,
            dataloader=train_dataloader,
        )
    else:
        start_step = 0
        epoch = 0

    perf_logger = PerfLogger(dist_config, args)

    # Reclaim init-time allocations before the steady-state training loop.
    gc.collect()
    torch.cuda.empty_cache()

    step = start_step
    micro_step = 0  # counts microbatches within one gradient-accumulation window
    while step < args.num_train_steps:
        for batch in train_dataloader:
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}  # noqa: PLW2901

            micro_step += 1
            # Skip DDP gradient all-reduce until the final microbatch of the window.
            sync_ctx = nullcontext() if micro_step % args.grad_acc_steps == 0 else model.no_sync()
            with sync_ctx:
                # Forward under TE autocast; FP8 only takes effect when enabled in config.
                with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe):
                    outputs = model(**batch)

                # Scale so accumulated gradients average (not sum) over microbatches.
                loss = outputs.loss / args.grad_acc_steps
                loss.backward()

            perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs)

            # Optimizer step only once per accumulation window.
            if micro_step % args.grad_acc_steps == 0:
                micro_step = 0

                total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                perf_logger.log_step(
                    step=step,
                    grad_norm=total_norm,
                    lr=optimizer.param_groups[0]["lr"],
                )

                if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
                    save_checkpoint_ddp(
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        ckpt_path=ckpt_path,
                        step=step,
                        epoch=epoch,
                        dist_config=dist_config,
                        dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
                        max_checkpoints=args.checkpoint.max_checkpoints,
                    )

                step += 1
                if step >= args.num_train_steps:
                    break

        # Dataloader exhausted: advance the epoch so shuffling reseeds deterministically.
        epoch += 1
        if dataset_or_sampler is not None:
            dataset_or_sampler.set_epoch(epoch)

    # Persist the final weights as .safetensors when configured.
    if args.checkpoint.save_final_model and ckpt_path:
        save_final_model_ddp(
            model=model,
            save_directory=ckpt_path / "final_model",
            dist_config=dist_config,
        )

    # Close the dataloader wrapper before tearing down the process group to avoid hangs
    # from in-flight prefetch threads.
    train_dataloader.close()
    perf_logger.finish()
    torch.distributed.destroy_process_group()

    return perf_logger.min_loss


if __name__ == "__main__":
    main()