diff --git a/bionemo-recipes/models/esm2/src/esm/collator.py b/bionemo-recipes/models/esm2/src/esm/collator.py index 63b4716410..dcf998c797 100644 --- a/bionemo-recipes/models/esm2/src/esm/collator.py +++ b/bionemo-recipes/models/esm2/src/esm/collator.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Data collator for THD input format tests. +"""Data collators for sequence packing and context parallel training. This should eventually get moved to a separate package, or possibly upstreamed into `transformers`. """ @@ -674,6 +674,16 @@ def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tupl def _pt_flatten_collate(features: list[dict[str, list[int]]], return_position_ids: bool = False): + """Flatten a list of tokenized samples into a single packed batch with cumulative sequence lengths. + + Args: + features: List of tokenized samples, each containing at least ``input_ids``. + return_position_ids: Whether to return position ids for each token. + + Returns: + A dictionary with packed ``input_ids``, ``cu_seq_lens_q``/``cu_seq_lens_k``, and + ``max_length_q``/``max_length_k``. 
+ """ is_labels_provided = "labels" in features[0] sample_lengths = [len(sample["input_ids"]) for sample in features] @@ -920,7 +930,7 @@ def process_tensor_bshd(val): class BatchType(TypedDict): - """The fields in the batch dictionary fo THD context parallel.""" + """The fields in the batch dictionary for THD context parallel.""" input_ids: torch.Tensor labels: torch.Tensor | None diff --git a/bionemo-recipes/models/llama3/README.md b/bionemo-recipes/models/llama3/README.md index 758de85efb..a5fa2efbaf 100644 --- a/bionemo-recipes/models/llama3/README.md +++ b/bionemo-recipes/models/llama3/README.md @@ -1,6 +1,166 @@ -# 🚧 Llama-3.1 Optimized with NVIDIA TransformerEngine +# Llama-3.1 Optimized with NVIDIA TransformerEngine -This folder contains source code and tests for an Llama-3.1 model that inherits from the transformers `PreTrainedModel` -class and uses TransformerEngine layers. +This folder contains source code and tests for Llama-3.\* style models that inherit from the transformers +`PreTrainedModel` class and uses TransformerEngine layers. Unlike the ESM-2 model, we do not currently distribute +pre-converted TE checkpoints on HuggingFace Hub. Instead, users can convert existing Llama 3 checkpoints from +HuggingFace using the provided conversion utilities. -This folder is currently work in progress and is not yet ready for general use. 
+## Feature support
+
+The Llama-3 implementation natively supports the following TransformerEngine-provided optimizations:
+
+| Feature | Support |
+| --------------------------------------- | -------------------------------------------------------------------------------- |
+| **FP8** | ✅ Supported on compute capability 9.0 and above (Hopper+) |
+| **MXFP8** | ✅ Supported on compute capability 10.0 and 10.3 (Blackwell), 12.0 support pending |
+| **Sequence Packing / THD input format** | ✅ Supported |
+| **FP8 with THD input format** | ✅ Supported where FP8 is supported |
+| **Import from HuggingFace checkpoints** | ✅ Supported |
+| **Export to HuggingFace checkpoints** | ✅ Supported |
+| **KV-cache inference** | ✅ Supported (including beam search) |
+| **Context Parallelism** | ✅ Supported |
+| **Tensor Parallelism** | 🚧 Under development |
+
+Refer to [BioNeMo Recipes](../../recipes/llama3_native_te/README.md) for more details on how to use these features to accelerate model
+training and inference with native PyTorch training loops.
+
+## Inference Examples
+
+### Quick start: convert and run
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from convert import convert_llama_hf_to_te
+
+# Load the original HuggingFace Llama 3 model
+model_hf = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-3.2-1B-Instruct", dtype=torch.bfloat16
+)
+
+# Convert to TransformerEngine.
+model_te = convert_llama_hf_to_te(model_hf) +model_te.to("cuda") + +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct") +tokenizer.pad_token = tokenizer.eos_token + +inputs = tokenizer("The quick brown fox", return_tensors="pt") +inputs = {k: v.to("cuda") for k, v in inputs.items()} + +with torch.no_grad(): + output_ids = model_te.generate(**inputs, max_new_tokens=16, use_cache=False) + +print(tokenizer.decode(output_ids[0], skip_special_tokens=True)) +``` + +### Inference with KV-cache + +For efficient autoregressive generation, use the TE-provided `InferenceParams` KV-cache: + +```python +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformer_engine.pytorch.attention import InferenceParams + +from convert import convert_llama_hf_to_te + +model_hf = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-3.2-1B-Instruct", torch_dtype=torch.bfloat16 +) +model_te = convert_llama_hf_to_te( + model_hf, attn_input_format="thd", self_attn_mask_type="padding_causal" +) +model_te.to("cuda") + +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct") + +inputs = tokenizer("The quick brown fox", return_tensors="pt") +inputs = {k: v.to("cuda") for k, v in inputs.items()} + +# Allocate KV-cache +past_key_values = InferenceParams( + max_batch_size=1, + max_sequence_length=256, + num_heads_kv=model_te.config.num_key_value_heads, + head_dim_k=model_te.config.hidden_size // model_te.config.num_attention_heads, + dtype=torch.bfloat16, + qkv_format="thd", + max_ctx_len=256, +) + +for layer_number in range(1, model_te.config.num_hidden_layers + 1): + past_key_values.allocate_memory(layer_number) + +with torch.no_grad(): + output_ids = model_te.generate( + **inputs, + max_new_tokens=16, + use_cache=True, + past_key_values=past_key_values, + ) + +print(tokenizer.decode(output_ids[0], skip_special_tokens=True)) +``` + +## Recipe Links + +Training recipes are available in the 
`bionemo-recipes/recipes/` directory: + +- **[llama3_native_te](../../recipes/llama3_native_te/)** - Demonstrates training with a native PyTorch training loop + using FSDP2, including FP8, sequence packing, and context parallelism. + +## Converting Between Model Formats + +This section explains how to convert between Hugging Face Transformers and Transformer Engine (TE) Llama 3 model +formats. The process demonstrates bidirectional conversion: from Transformers to TE format for optimized training and +inference, and back to Hugging Face Transformers format for sharing and deployment. + +### Converting from HF Transformers to TE + +```python +from transformers import AutoModelForCausalLM + +from convert import convert_llama_hf_to_te + +model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct") +model_te = convert_llama_hf_to_te(model_hf) +model_te.save_pretrained("/path/to/te_checkpoint") +``` + +### Converting from TE back to HF Transformers + +```python +from convert import convert_llama_te_to_hf +from modeling_llama_te import NVLlamaForCausalLM + +model_te = NVLlamaForCausalLM.from_pretrained("/path/to/te_checkpoint") +model_hf = convert_llama_te_to_hf(model_te) +model_hf.save_pretrained("/path/to/hf_checkpoint") +``` + +Once converted back to HF format, the model can be loaded by any library that supports Llama 3, such as +[vLLM](https://github.com/vllm-project/vllm) or [SGLang](https://github.com/sgl-project/sglang). + +### Validating Converted Models + +To validate the converted models, refer to the commands in [Inference Examples](#inference-examples) above to load and +test both the original and converted models to ensure loss and logit values are similar. Additionally, refer to the +golden value tests in [test_modeling_llama_te.py](tests/test_modeling_llama_te.py). + +## Developer Guide + +### Running tests + +To run tests locally, run `recipes_local_test.py` from the repository root with the model directory as an argument. 
+ +```bash +./ci/scripts/recipes_local_test.py bionemo-recipes/models/llama3/ +``` + +### Development container + +To use the provided devcontainer, use "Dev Containers: Reopen in Container" from the VSCode menu, and choose the +"BioNeMo Recipes Dev Container" option. To run the tests inside the container, first install the model package in +editable mode with `pip install -e .`, then run `pytest -v .` in the model directory. diff --git a/bionemo-recipes/models/llama3/collator.py b/bionemo-recipes/models/llama3/collator.py index 63b4716410..dcf998c797 100644 --- a/bionemo-recipes/models/llama3/collator.py +++ b/bionemo-recipes/models/llama3/collator.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Data collator for THD input format tests. +"""Data collators for sequence packing and context parallel training. This should eventually get moved to a separate package, or possibly upstreamed into `transformers`. """ @@ -674,6 +674,16 @@ def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tupl def _pt_flatten_collate(features: list[dict[str, list[int]]], return_position_ids: bool = False): + """Flatten a list of tokenized samples into a single packed batch with cumulative sequence lengths. + + Args: + features: List of tokenized samples, each containing at least ``input_ids``. + return_position_ids: Whether to return position ids for each token. + + Returns: + A dictionary with packed ``input_ids``, ``cu_seq_lens_q``/``cu_seq_lens_k``, and + ``max_length_q``/``max_length_k``. 
+ """ is_labels_provided = "labels" in features[0] sample_lengths = [len(sample["input_ids"]) for sample in features] @@ -920,7 +930,7 @@ def process_tensor_bshd(val): class BatchType(TypedDict): - """The fields in the batch dictionary fo THD context parallel.""" + """The fields in the batch dictionary for THD context parallel.""" input_ids: torch.Tensor labels: torch.Tensor | None diff --git a/bionemo-recipes/models/llama3/modeling_llama_te.py b/bionemo-recipes/models/llama3/modeling_llama_te.py index d4821be1f7..433edd084a 100644 --- a/bionemo-recipes/models/llama3/modeling_llama_te.py +++ b/bionemo-recipes/models/llama3/modeling_llama_te.py @@ -62,7 +62,7 @@ def init_empty_weights(self): if hasattr(module, "reset_parameters"): module.reset_parameters() - # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # The embed_tokens layer is the only non-TE layer in this model we need to deal with. We use # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard # deviation. self.model.embed_tokens.to_empty(device="cuda") @@ -363,9 +363,10 @@ def forward( ) -class NVLlamaForSequenceClassification( # noqa: D101 +class NVLlamaForSequenceClassification( transformers.modeling_layers.GenericForSequenceClassification, NVLlamaPreTrainedModel -): ... +): + """Llama3 model with sequence classification head.""" class NVLlamaForQuestionAnswering(transformers.modeling_layers.GenericForQuestionAnswering, NVLlamaPreTrainedModel): @@ -374,9 +375,10 @@ class NVLlamaForQuestionAnswering(transformers.modeling_layers.GenericForQuestio base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model` -class NVLlamaForTokenClassification( # noqa: D101 +class NVLlamaForTokenClassification( transformers.modeling_layers.GenericForTokenClassification, NVLlamaPreTrainedModel -): ... 
+): + """Llama3 model with token classification head.""" torch._dynamo.config.capture_scalar_outputs = True diff --git a/bionemo-recipes/recipes/esm2_native_te/collator.py b/bionemo-recipes/recipes/esm2_native_te/collator.py index 63b4716410..dcf998c797 100644 --- a/bionemo-recipes/recipes/esm2_native_te/collator.py +++ b/bionemo-recipes/recipes/esm2_native_te/collator.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Data collator for THD input format tests. +"""Data collators for sequence packing and context parallel training. This should eventually get moved to a separate package, or possibly upstreamed into `transformers`. """ @@ -674,6 +674,16 @@ def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tupl def _pt_flatten_collate(features: list[dict[str, list[int]]], return_position_ids: bool = False): + """Flatten a list of tokenized samples into a single packed batch with cumulative sequence lengths. + + Args: + features: List of tokenized samples, each containing at least ``input_ids``. + return_position_ids: Whether to return position ids for each token. + + Returns: + A dictionary with packed ``input_ids``, ``cu_seq_lens_q``/``cu_seq_lens_k``, and + ``max_length_q``/``max_length_k``. 
+ """ is_labels_provided = "labels" in features[0] sample_lengths = [len(sample["input_ids"]) for sample in features] @@ -920,7 +930,7 @@ def process_tensor_bshd(val): class BatchType(TypedDict): - """The fields in the batch dictionary fo THD context parallel.""" + """The fields in the batch dictionary for THD context parallel.""" input_ids: torch.Tensor labels: torch.Tensor | None diff --git a/bionemo-recipes/recipes/esm2_native_te/fp8_debugging.py b/bionemo-recipes/recipes/esm2_native_te/fp8_debugging.py index 6cd452ebb8..d01024f04c 100644 --- a/bionemo-recipes/recipes/esm2_native_te/fp8_debugging.py +++ b/bionemo-recipes/recipes/esm2_native_te/fp8_debugging.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-Apache2 # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/bionemo-recipes/recipes/esm2_peft_te/collator.py b/bionemo-recipes/recipes/esm2_peft_te/collator.py index 63b4716410..dcf998c797 100644 --- a/bionemo-recipes/recipes/esm2_peft_te/collator.py +++ b/bionemo-recipes/recipes/esm2_peft_te/collator.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Data collator for THD input format tests. +"""Data collators for sequence packing and context parallel training. This should eventually get moved to a separate package, or possibly upstreamed into `transformers`. """ @@ -674,6 +674,16 @@ def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tupl def _pt_flatten_collate(features: list[dict[str, list[int]]], return_position_ids: bool = False): + """Flatten a list of tokenized samples into a single packed batch with cumulative sequence lengths. + + Args: + features: List of tokenized samples, each containing at least ``input_ids``. 
+ return_position_ids: Whether to return position ids for each token. + + Returns: + A dictionary with packed ``input_ids``, ``cu_seq_lens_q``/``cu_seq_lens_k``, and + ``max_length_q``/``max_length_k``. + """ is_labels_provided = "labels" in features[0] sample_lengths = [len(sample["input_ids"]) for sample in features] @@ -920,7 +930,7 @@ def process_tensor_bshd(val): class BatchType(TypedDict): - """The fields in the batch dictionary fo THD context parallel.""" + """The fields in the batch dictionary for THD context parallel.""" input_ids: torch.Tensor labels: torch.Tensor | None diff --git a/bionemo-recipes/recipes/llama3_native_te/README.md b/bionemo-recipes/recipes/llama3_native_te/README.md index b7416929cc..aa776026d0 100644 --- a/bionemo-recipes/recipes/llama3_native_te/README.md +++ b/bionemo-recipes/recipes/llama3_native_te/README.md @@ -16,9 +16,9 @@ bionemo-framework repository. You can download a zipped directory of this folder ## Supported Models and Training Features -| Model | BF16 | FP8[1] | THD Input Format | FP8 with THD Input Format | MXFP8[2] | Context Parallelism | -| ---------------------------------------- | ---- | ----------------- | ---------------- | ------------------------- | ------------------- | ------------------- | -| [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | 🚧 | 🚧 | 🚧 | +| Model | BF16 | FP8[1] | THD Input Format | FP8 with THD Input Format | MXFP8[2] | Context Parallelism | Tensor Parallelism | +| ---------------------------------------- | ---- | ----------------- | ---------------- | ------------------------- | ------------------- | ------------------- | ------------------ | +| [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅: Supported
🚧: Under development
@@ -42,7 +42,8 @@ To run the container, run: docker run -it --gpus all --network host --ipc=host --rm -v ${PWD}:/workspace/bionemo llama3_native_te /bin/bash ``` -Alternatively, the dependencies can be installed manually in an environment with CUDA support. See `requirements.txt` for the list of dependencies. +Alternatively, the dependencies can be installed manually in an environment with CUDA support. See `requirements.txt` +for the list of dependencies. ### Performance Benchmarks @@ -70,10 +71,12 @@ Training was performed with BF16 precision. ### Distributed Training -This recipe supports distributed training using DDP and FSDP2, shown in two separate training entrypoints: +This recipe supports distributed training using DDP, FSDP2, FSDP2 with Context Parallelism, and Megatron-FSDP with Context Parallelism, shown in four separate training entrypoints: - [Distributed Data Parallel (DDP)](https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html), shown in `train_ddp.py` - [Fully Sharded Data Parallel 2 (FSDP2)](https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html), shown in `train_fsdp2.py` +- FSDP2 with Context Parallelism, shown in `train_fsdp2_cp.py` +- Megatron-FSDP with Context Parallelism, shown in `train_mfsdp_cp.py` ## Commands to Launch Training @@ -91,6 +94,21 @@ torchrun --nproc_per_node=2 train_fsdp2.py # or train_ddp.py Multi-Node training is supported with both strategies. +A convergence test configuration (`L0_convergence`) is also available, which uses a tiny Llama model +to verify that the training loop can overfit on a small dataset: + +```bash +python train_fsdp2.py --config-name L0_convergence +``` + +Gradient accumulation is supported with both strategies. To enable gradient accumulation, set `grad_acc_steps` to the +number of steps to accumulate gradients before updating the model parameters. This is useful to scale the effective +batch size while running on a smaller number of GPUs. 
+ +```bash +python train_fsdp2.py --config-name L0_sanity grad_acc_steps=2 +``` + ### FP8 Training To run training with FP8, enable it by overriding the `fp8_config.enabled=true` configuration parameter. Additional FP8 @@ -106,15 +124,15 @@ We also provide a mechanism to receive tensor data related to FP8 layers during To enable this please select the following config options. -```python +```bash python train_fsdp2.py \ -fp8_stats_config.enabled=True # whether to log stats or not -fp8_stats_config.fp8_log_dir=./logs/fp8_stats_logs_dummy # where to store the logs -fp8_stats_config.fp8_stats_file=./fp8_debugging_stats.yaml # specifies what stats you want to run. Currently this is saved in this yaml file. -fp8_config.enabled=True # set this to use FP8 otherwise stats logging wont work + fp8_stats_config.enabled=True \ + fp8_stats_config.fp8_log_dir=./logs/fp8_stats_logs_dummy \ + fp8_stats_config.fp8_stats_file=./fp8_debugging_stats.yaml \ + fp8_config.enabled=True ``` -Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. It is not yet available for `train_mfsdp`. +Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. The config file structure [fp8_debugging_stats.yaml](fp8_debugging_stats.yaml) is explained in the [NVIDIA Transformer Engine config file documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/debug/2_config_file_structure.html) in more detail. Below we will cover some very basic elements of the file structure. @@ -142,6 +160,30 @@ python train_fsdp2.py --config-name L0_sanity \ use_sequence_packing=true ``` +### Context Parallel Training + +Context parallelism splits each sequence across multiple GPUs along the sequence dimension, enabling training with very +long sequences. Use `train_fsdp2_cp.py` with the `L0_sanity_cp` configuration and set `cp_size` to the number of context +parallelism ranks. Works with both BSHD (no padding) and THD (padding) input formats. 
Only TE models are supported. + +```bash +torchrun --nproc_per_node=4 train_fsdp2_cp.py --config-name L0_sanity_cp cp_size=2 +``` + +### Megatron-FSDP with Context Parallelism + +Megatron-FSDP (`train_mfsdp_cp.py`) provides an alternative FSDP implementation from the `megatron-fsdp` package with +context parallelism support. It creates a 3D device mesh `(dp, cp, tp)` where `tp` is a dummy dimension of size 1. Only +TE models are supported. Note that `torch.compile` is not supported with Megatron-FSDP. + +```bash +# Single GPU (cp_size=1) +python train_mfsdp_cp.py --config-name L0_sanity_cp cp_size=1 + +# Multi-GPU with context parallelism +torchrun --nproc_per_node=4 train_mfsdp_cp.py --config-name L0_sanity_cp cp_size=2 +``` + ## Downloading Pre-Training Data For Offline Training This recipe is configured to use genomic sequences. The default configuration uses a local test file diff --git a/bionemo-recipes/recipes/llama3_native_te/checkpoint.py b/bionemo-recipes/recipes/llama3_native_te/checkpoint.py index a5af170661..7d3263f2f2 100644 --- a/bionemo-recipes/recipes/llama3_native_te/checkpoint.py +++ b/bionemo-recipes/recipes/llama3_native_te/checkpoint.py @@ -40,6 +40,10 @@ logger = logging.getLogger(__name__) + +# Tracks in-flight async checkpoint futures keyed by strategy name (e.g. "fsdp2"). +# Each entry holds the Future returned by dcp_async_save so we can await it before starting +# the next async save or before shutting down. 
_ckpt_futures: dict = {} @@ -379,6 +383,147 @@ def save_final_model_fsdp2( logger.info(f"Saved final FSDP2 model to {save_directory} (weights + config only)") +# ============================================================================ +# mFSDP Checkpointing +# ============================================================================ + + +def load_checkpoint_mfsdp( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + ckpt_path: str | os.PathLike, + dist_config: DistributedConfig, + dataloader: StatefulDataLoader | None = None, +) -> CheckpointOutput: + """Load mFSDP distributed checkpoint. + + Args: + model: The model to load. + optimizer: The optimizer to load. + scheduler: The LR scheduler to load. + ckpt_path: The directory containing checkpoints. + dist_config: The distributed configuration. + dataloader: The dataloader to load. + + Returns: + Tuple of (model, optimizer, scheduler, dataloader, step, epoch). + """ + checkpoint_path, step = get_latest_checkpoint(ckpt_path) + if not checkpoint_path: + logger.info("No mFSDP checkpoint found, starting from scratch") + return CheckpointOutput(model, optimizer, scheduler, dataloader, 0, 0) + + ckpt_state_dict = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "metadata": { + "step": step, # Initialize with current step from filename + "epoch": 0, # Initialize with default epoch + }, + } + torch.distributed.checkpoint.load(state_dict=ckpt_state_dict, checkpoint_id=checkpoint_path) + + model.load_state_dict(ckpt_state_dict["model"], strict=False) + optimizer.load_state_dict(ckpt_state_dict["optimizer"]) + scheduler.load_state_dict(ckpt_state_dict["scheduler"]) + dataloader = load_dataloader(dataloader, checkpoint_path, dist_config) + + step = ckpt_state_dict["metadata"]["step"] + epoch = ckpt_state_dict["metadata"]["epoch"] + + # Ensure all ranks have completed loading before proceeding + 
torch.distributed.barrier() + + logger.info(f"Loaded mFSDP checkpoint from step {step}") + + # Increment the step by one to avoid re-running the previous step. + return CheckpointOutput(model, optimizer, scheduler, dataloader, step + 1, epoch) + + +def save_checkpoint_mfsdp( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + ckpt_path: str | os.PathLike, + step: int, + epoch: int, + dist_config: DistributedConfig, + dataloader: StatefulDataLoader | None = None, + max_checkpoints: int | None = None, +) -> None: + """Save mFSDP distributed checkpoint. + + Args: + model: The model to save. + optimizer: The optimizer to save. + scheduler: The LR scheduler to save. + ckpt_path: The directory to save the checkpoint. + step: The step number to save the checkpoint. + epoch: The epoch number to save the checkpoint. + dist_config: The distributed configuration. + dataloader: The dataloader to save. + max_checkpoints: The maximum number of checkpoints to keep. + """ + ckpt_path = Path(ckpt_path) + checkpoint_path = ckpt_path / f"step_{step}" + checkpoint_path.mkdir(parents=True, exist_ok=True) + + # Save dataloader state, if provided. 
+ save_dataloader(dataloader, checkpoint_path, dist_config) + + # Save model, optimizer, scheduler state, and metadata + state_dict = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "metadata": { + "step": step, + "epoch": epoch, + }, + } + + torch.distributed.checkpoint.save(state_dict, checkpoint_id=checkpoint_path) + + if dist_config.is_main_process(): + logger.info(f"Saved mFSDP checkpoint to {checkpoint_path}") + + if max_checkpoints is not None and dist_config.is_main_process(): + prune_checkpoints(ckpt_path, max_checkpoints) + + +def save_final_model_mfsdp( + model: torch.nn.Module, + save_directory: str | os.PathLike, + dist_config: DistributedConfig, +) -> None: + """Save final model for mFSDP - requires parameter gathering on all ranks.""" + from megatron_fsdp.uneven_dtensor import gather_uneven_dtensor_to_full_tensor + + if dist_config.is_main_process(): + logger.info("Starting mFSDP parameter gathering...") + + # Parameter gathering must happen on ALL processes + unsharded_state_dict = { + # Gather all parameters to CPU, and remove the "module." prefix from the Megatron-FSDP class wrapper. 
+ k.removeprefix("module."): gather_uneven_dtensor_to_full_tensor( + v, target_device=torch.device("cpu") + ).to_local() + if isinstance(v, torch.distributed.tensor.DTensor) + else v + for k, v in model.state_dict().items() + } + + # Only main process saves the model + if not dist_config.is_main_process(): + return + + os.makedirs(save_directory, exist_ok=True) + model.module.save_pretrained(save_directory, state_dict=unsharded_state_dict, safe_serialization=True) + logger.info(f"Saved final mFSDP model to {save_directory}") + + # ============================================================================ # Dataloader Checkpointing # ============================================================================ @@ -439,7 +584,7 @@ def load_dataloader( ) return dataloader - dataloader_state = torch.load(dataloader_path) + dataloader_state = torch.load(dataloader_path, weights_only=True) if ( dataloader.num_workers != dataloader_state["num_workers"] diff --git a/bionemo-recipes/recipes/llama3_native_te/collator.py b/bionemo-recipes/recipes/llama3_native_te/collator.py index 63b4716410..dcf998c797 100644 --- a/bionemo-recipes/recipes/llama3_native_te/collator.py +++ b/bionemo-recipes/recipes/llama3_native_te/collator.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Data collator for THD input format tests. +"""Data collators for sequence packing and context parallel training. This should eventually get moved to a separate package, or possibly upstreamed into `transformers`. """ @@ -674,6 +674,16 @@ def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tupl def _pt_flatten_collate(features: list[dict[str, list[int]]], return_position_ids: bool = False): + """Flatten a list of tokenized samples into a single packed batch with cumulative sequence lengths. + + Args: + features: List of tokenized samples, each containing at least ``input_ids``. 
+ return_position_ids: Whether to return position ids for each token. + + Returns: + A dictionary with packed ``input_ids``, ``cu_seq_lens_q``/``cu_seq_lens_k``, and + ``max_length_q``/``max_length_k``. + """ is_labels_provided = "labels" in features[0] sample_lengths = [len(sample["input_ids"]) for sample in features] @@ -920,7 +930,7 @@ def process_tensor_bshd(val): class BatchType(TypedDict): - """The fields in the batch dictionary fo THD context parallel.""" + """The fields in the batch dictionary for THD context parallel.""" input_ids: torch.Tensor labels: torch.Tensor | None diff --git a/bionemo-recipes/recipes/llama3_native_te/dataset.py b/bionemo-recipes/recipes/llama3_native_te/dataset.py index 6c0b47cf68..3964056d2f 100644 --- a/bionemo-recipes/recipes/llama3_native_te/dataset.py +++ b/bionemo-recipes/recipes/llama3_native_te/dataset.py @@ -239,7 +239,6 @@ def create_thd_dataloader( prefetch_factor: The prefetch factor to use for the dataloader. max_seq_length: The maximum length of sequences (window size). stride: The stride for windowing (overlap = stride tokens). - seed: The seed to use for the distributed sampler and data collator. buffer_size: The buffer size for shuffle. use_stateful_dataloader: Whether to use the StatefulDataLoader to enable checkpointing the dataloader state. text_column: Name of the column containing genomic sequences (default: "text"). diff --git a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py b/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py index 6cd452ebb8..d01024f04c 100644 --- a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py +++ b/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: LicenseRef-Apache2 # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/bionemo-recipes/recipes/llama3_native_te/genomic_dataset.py b/bionemo-recipes/recipes/llama3_native_te/genomic_dataset.py index 12d6bad16f..0e35550c65 100644 --- a/bionemo-recipes/recipes/llama3_native_te/genomic_dataset.py +++ b/bionemo-recipes/recipes/llama3_native_te/genomic_dataset.py @@ -17,13 +17,12 @@ Core functions for genomic data preprocessing during training: - make_upper_case: Convert lowercase tokens to uppercase -- Evo2MaskingConstants: Standard DNA tokens and control characters Adapted from NeMo's Evo2 implementation. """ from dataclasses import dataclass -from typing import Any, ClassVar +from typing import Any import torch @@ -47,16 +46,6 @@ def _make_upper_case(tokens, lowercase_start=97, lowercase_end=122, case_diff=32 return uppercase_tensor, lowercase_mask -class Evo2MaskingConstants: - """Constants used in Evo2 genomic sequence masking.""" - - # Standard DNA tokens: A, C, G, T (both uppercase and lowercase) - DNA_TOKENS: ClassVar[list[int]] = [65, 67, 71, 84, 97, 99, 103, 116] - - # Control characters used in data formatting - CONTROL_TAGS: ClassVar[list[int]] = [64, 35] # '@', '#' - - @dataclass class GenomicDataCollator: """Wrapper collator that adds genomic-specific masking to any base collator. 
diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml index d6c181598f..c50909a95b 100644 --- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml @@ -75,6 +75,19 @@ fp8_stats_config: fp8_stats_file: ./fp8_debugging_stats.yaml fp8_log_dir: ./log_fp8_stats +# mFSDP config +fully_shard_kwargs: + zero_dp_strategy: "optim_grads_params" + calculate_per_token_loss: false + init_model_with_meta_device: ${use_meta_device} + check_for_nan_in_grad: true + grad_reduce_in_fp32: false + preserve_fp32_weights: true + overlap_grad_reduce: true + overlap_param_gather: true + sync_model_each_microbatch: true + average_in_collective: false + profiler: enabled: false start_step: 10 diff --git a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py index d4821be1f7..433edd084a 100644 --- a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py +++ b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py @@ -62,7 +62,7 @@ def init_empty_weights(self): if hasattr(module, "reset_parameters"): module.reset_parameters() - # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # The embed_tokens layer is the only non-TE layer in this model we need to deal with. We use # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard # deviation. self.model.embed_tokens.to_empty(device="cuda") @@ -363,9 +363,10 @@ def forward( ) -class NVLlamaForSequenceClassification( # noqa: D101 +class NVLlamaForSequenceClassification( transformers.modeling_layers.GenericForSequenceClassification, NVLlamaPreTrainedModel -): ... 
+): + """Llama3 model with sequence classification head.""" class NVLlamaForQuestionAnswering(transformers.modeling_layers.GenericForQuestionAnswering, NVLlamaPreTrainedModel): @@ -374,9 +375,10 @@ class NVLlamaForQuestionAnswering(transformers.modeling_layers.GenericForQuestio base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model` -class NVLlamaForTokenClassification( # noqa: D101 +class NVLlamaForTokenClassification( transformers.modeling_layers.GenericForTokenClassification, NVLlamaPreTrainedModel -): ... +): + """Llama3 model with token classification head.""" torch._dynamo.config.capture_scalar_outputs = True diff --git a/bionemo-recipes/recipes/llama3_native_te/requirements.txt b/bionemo-recipes/recipes/llama3_native_te/requirements.txt index 073d9b39e3..672a3b7e1b 100644 --- a/bionemo-recipes/recipes/llama3_native_te/requirements.txt +++ b/bionemo-recipes/recipes/llama3_native_te/requirements.txt @@ -1,5 +1,6 @@ datasets hydra-core +megatron-fsdp torch torchao!=0.14.0 torchdata diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_genomic_dataset.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_genomic_dataset.py index 0aa01c260c..f23265e719 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_genomic_dataset.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_genomic_dataset.py @@ -29,7 +29,7 @@ def tokenizer(tokenizer_path): return AutoTokenizer.from_pretrained(tokenizer_path) -# Tests for GenomicDataCollatorForCLM +# Tests for GenomicDataCollator def test_collator_basic(tokenizer): """Test basic collator functionality.""" base = DataCollatorForLanguageModeling(tokenizer, mlm=False) diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py index 0fb725092a..effffbadf4 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py +++ 
b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py @@ -25,6 +25,7 @@ from train_ddp import main as main_ddp from train_fsdp2 import main as main_fsdp2 from train_fsdp2_cp import main as main_fsdp2_cp +from train_mfsdp_cp import main as main_mfsdp_cp # TODO(@jomitchell): Delete once https://nvbugspro.nvidia.com/bug/5458694 is fixed. @@ -420,7 +421,7 @@ def test_train_fsdp2_fp8_thd(tmp_path, recipe_path): @requires_datacenter_hardware def test_sanity_fsdp2_cp(tmp_path, recipe_path): - # Run the training script with Hydra configuration overrides + """Test FSDP2 with context parallelism training on a single GPU (cp_size=1).""" with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): sanity_config = compose( config_name="L0_sanity_cp", @@ -439,6 +440,49 @@ def test_sanity_fsdp2_cp(tmp_path, recipe_path): assert torch.isfinite(torch.tensor(final_loss)), f"Final loss {final_loss} is not finite" +def test_sanity_convergence_mfsdp_cp_bshd(tmp_path, recipe_path): + """Test Megatron-FSDP with context parallelism training on a single GPU (cp_size=1) using BSHD format.""" + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = compose( + config_name="L0_sanity_cp", + overrides=[ + f"+wandb.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "checkpoint.resume_from_checkpoint=false", + "config_kwargs.attn_input_format=bshd", + "config_kwargs.self_attn_mask_type=causal", + ], + ) + + final_loss = main_mfsdp_cp(sanity_config) + gc.collect() + torch.cuda.empty_cache() + + assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0" + + +@requires_datacenter_hardware +def test_sanity_convergence_mfsdp_cp_thd(tmp_path, recipe_path): + """Test Megatron-FSDP with context parallelism training on a single GPU (cp_size=1) using THD format.""" + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + sanity_config = 
compose( + config_name="L0_sanity_cp", + overrides=[ + f"+wandb.dir={tmp_path}", + f"checkpoint.ckpt_dir={tmp_path}", + "checkpoint.resume_from_checkpoint=false", + "config_kwargs.attn_input_format=thd", + "config_kwargs.self_attn_mask_type=padding_causal", + ], + ) + + final_loss = main_mfsdp_cp(sanity_config) + gc.collect() + torch.cuda.empty_cache() + + assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0" + + @requires_fp8 def test_sanity_ddp_fp8_stats_logging(tmp_path, recipe_path): """Test that FP8 stats logging creates the expected log files.""" diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py index d64e6e4eed..0d1e4f8e6b 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py @@ -90,12 +90,11 @@ def test_multi_gpu_train_ddp(recipe_path): "torchrun", "--standalone", "--nproc_per_node", - "2", # 2 processes = 2 GPUs - "--standalone", # Single node mode + "2", "train_ddp.py", "--config-name", "L0_sanity", - "num_train_steps=4", # Just 4 steps for speed + "num_train_steps=4", ], recipe_path, ) @@ -118,12 +117,11 @@ def test_multi_gpu_train_fsdp2(recipe_path): "torchrun", "--standalone", "--nproc_per_node", - "2", # 2 processes = 2 GPUs - "--standalone", # Single node mode + "2", "train_fsdp2.py", "--config-name", "L0_sanity", - "num_train_steps=4", # Just 4 steps for speed + "num_train_steps=4", ], recipe_path, ) @@ -144,7 +142,6 @@ def test_multi_gpu_train_ddp_with_checkpointing(tmp_path, recipe_path): "--standalone", "--nproc_per_node", "2", - "--standalone", "train_ddp.py", "--config-name", "L0_sanity", @@ -177,7 +174,6 @@ def test_multi_gpu_train_fsdp2_with_checkpointing(tmp_path, recipe_path): "--standalone", "--nproc_per_node", "2", - "--standalone", "train_fsdp2.py", "--config-name", "L0_sanity", @@ -197,12 +193,12 @@ def 
test_multi_gpu_train_fsdp2_with_checkpointing(tmp_path, recipe_path): @requires_multi_gpu def test_multi_gpu_train_te_fsdp2_cp_bshd(tmp_path, recipe_path): + """Test FSDP2 with context parallelism on 2 GPUs using BSHD input format.""" run_train_cmd( [ "torchrun", "--standalone", "--nproc_per_node=2", - "--standalone", "train_fsdp2_cp.py", "--config-name", "L0_sanity_cp", @@ -221,12 +217,12 @@ def test_multi_gpu_train_te_fsdp2_cp_bshd(tmp_path, recipe_path): @requires_multi_gpu @requires_datacenter_hardware def test_multi_gpu_train_te_fsdp2_cp_thd(tmp_path, recipe_path): + """Test FSDP2 with context parallelism on 2 GPUs using THD input format.""" run_train_cmd( [ "torchrun", "--standalone", "--nproc_per_node=2", - "--standalone", "train_fsdp2_cp.py", "--config-name", "L0_sanity_cp", @@ -242,6 +238,76 @@ def test_multi_gpu_train_te_fsdp2_cp_thd(tmp_path, recipe_path): ) +@requires_multi_gpu +def test_multi_gpu_train_mfsdp_cp(tmp_path, recipe_path): + """Test Megatron-FSDP with context parallelism on 2 GPUs (cp_size=1, dp=2).""" + run_train_cmd( + [ + "torchrun", + "--standalone", + "--nproc_per_node=2", + "train_mfsdp_cp.py", + "--config-name", + "L0_sanity_cp", + "num_train_steps=10", + f"checkpoint.ckpt_dir={tmp_path}", + "checkpoint.save_every_n_steps=5", + "cp_size=1", + "use_sequence_packing=false", + "config_kwargs.attn_input_format=bshd", + "config_kwargs.self_attn_mask_type=causal", + ], + recipe_path, + ) + + +@requires_multi_gpu +def test_multi_gpu_train_mfsdp_cp_bshd(tmp_path, recipe_path): + """Test Megatron-FSDP with context parallelism on 2 GPUs (cp_size=2) using BSHD format.""" + run_train_cmd( + [ + "torchrun", + "--standalone", + "--nproc_per_node=2", + "train_mfsdp_cp.py", + "--config-name", + "L0_sanity_cp", + "num_train_steps=10", + f"checkpoint.ckpt_dir={tmp_path}", + "checkpoint.save_every_n_steps=5", + "cp_size=2", + "use_sequence_packing=false", + "config_kwargs.attn_input_format=bshd", + "config_kwargs.self_attn_mask_type=causal", + ], + 
recipe_path, + ) + + +@requires_multi_gpu +@requires_datacenter_hardware +def test_multi_gpu_train_mfsdp_cp_thd(tmp_path, recipe_path): + """Test Megatron-FSDP with context parallelism on 2 GPUs (cp_size=2) using THD format.""" + run_train_cmd( + [ + "torchrun", + "--standalone", + "--nproc_per_node=2", + "train_mfsdp_cp.py", + "--config-name", + "L0_sanity_cp", + "num_train_steps=10", + f"checkpoint.ckpt_dir={tmp_path}", + "checkpoint.save_every_n_steps=5", + "cp_size=2", + "use_sequence_packing=true", + "config_kwargs.attn_input_format=thd", + "config_kwargs.self_attn_mask_type=padding_causal", + ], + recipe_path, + ) + + nsys_available = subprocess.run(["which", "nsys"], check=False, capture_output=True).returncode == 0 diff --git a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py index 49d22901d3..7aed3ff6f5 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py @@ -13,6 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Distributed Data Parallel (DDP) training script for Llama 3 with TransformerEngine. + +Each GPU holds a full copy of the model and gradients are synchronized via all-reduce after each +backward pass. This is the simplest distributed strategy and works well for smaller models that fit +in a single GPU's memory. Supports both TE-accelerated (NVLlamaForCausalLM) and standard HuggingFace +(LlamaForCausalLM) models. + +For large models that do not fit on a single GPU, use ``train_fsdp2.py`` instead. 
+""" + +import gc import logging from contextlib import nullcontext from pathlib import Path @@ -22,7 +33,7 @@ import torch import transformer_engine import transformer_engine.pytorch -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from torch.distributed.device_mesh import init_device_mesh from torch.optim import AdamW from transformer_engine.common.recipe import Format @@ -49,7 +60,7 @@ def main(args: DictConfig) -> float | None: Returns: float: The loss value for the final batch. """ - # Initialize the distributed configuration, including creating the distributed process group. + # --- Distributed Setup --- dist_config = DistributedConfig() logger.info("Initializing distributed training: %s", dist_config) device = torch.device(f"cuda:{dist_config.local_rank}") @@ -60,11 +71,10 @@ def main(args: DictConfig) -> float | None: if args.fp8_stats_config.enabled: initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled) - # Create a device mesh for DDP. While this isn't strictly necessary, it mirrors the device mesh we create for FSDP2 - # and MFSDP. + # Create a device mesh for DDP. While this isn't strictly necessary, it mirrors the device mesh we create for FSDP2. device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("dp",)) - # Create an FP8 recipe -- this is only used if FP8 is enabled in the config. + # --- Model Configuration --- fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs ) @@ -76,12 +86,12 @@ def main(args: DictConfig) -> float | None: config_class = LlamaConfig model_class = LlamaForCausalLM - # Create an empty Llama3 model with a causal language model head, e.g. "meta-llama/Meta-Llama-3-8B". 
+ # --- Model Initialization --- config = config_class.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs) # Optionally use transformer engine to initialize only fp8 versions of weights by setting - # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 and fp8 - # versions of weights are kept. + # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 + # and fp8 versions of weights are kept. with transformer_engine.pytorch.quantized_model_init( recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs ): @@ -89,10 +99,7 @@ def main(args: DictConfig) -> float | None: logger.info("Initialized Model:\n%s", model) - # Create optimizer. - optimizer = AdamW(model.parameters(), **args.adamw_kwargs) - scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) - + # --- Distributed Wrapping (DDP) --- if args.fp8_stats_config.enabled: debug_api.infer_and_assign_layer_names(model) @@ -104,18 +111,25 @@ def main(args: DictConfig) -> float | None: device_mesh=device_mesh["dp"], ) - if args.use_sequence_packing: - train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset) - else: - train_dataloader, dataset_or_sampler = create_bshd_dataloader(dist_config, **args.dataset) + # --- Optimizer & Scheduler --- + # Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). + optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore + scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) if args.use_torch_compile: # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency. model = torch.compile(model) - # If we're resuming from a checkpoint, load it and set the start step. Otherwise, start from step 0. 
+ # --- Data Loading --- + if args.use_sequence_packing: + train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset) + else: + train_dataloader, dataset_or_sampler = create_bshd_dataloader(dist_config, **args.dataset) + + # --- Checkpoint Resume --- ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_ddp" if args.checkpoint.ckpt_dir else None if args.checkpoint.resume_from_checkpoint and ckpt_path: + logger.info("Attempting to load checkpoint from %s", ckpt_path) model, optimizer, scheduler, train_dataloader, start_step, epoch = load_checkpoint_ddp( model=model, optimizer=optimizer, @@ -124,22 +138,27 @@ def main(args: DictConfig) -> float | None: dist_config=dist_config, dataloader=train_dataloader, ) + logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch) else: + logger.info("No checkpoint to load, starting from scratch") start_step = 0 epoch = 0 perf_logger = PerfLogger(dist_config, args) - # Training loop + gc.collect() + torch.cuda.empty_cache() + + # --- Training Loop --- + logger.info("Starting training loop from step %s to %s", start_step, args.num_train_steps) step = start_step micro_step = 0 # Gradient accumulation step counter while step < args.num_train_steps: for batch in train_dataloader: - print(batch["input_ids"].shape) - batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa PLW2901 + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 micro_step += 1 - # Use no_sync to prevent gradient synchronization until the last microbatch + # DDP requires no_sync to skip all-reduce until the last microbatch in the accumulation window. with model.no_sync() if micro_step % args.grad_acc_steps != 0 else nullcontext(): # Forward pass with mixed precision. 
with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): @@ -157,7 +176,7 @@ def main(args: DictConfig) -> float | None: micro_step = 0 # Compute and clip gradient norms. - total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() + total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # Step optimizer. optimizer.step() @@ -191,7 +210,7 @@ def main(args: DictConfig) -> float | None: epoch += 1 dataset_or_sampler.set_epoch(epoch) - # Save final model to a .safetensors file. + # --- Cleanup --- if args.checkpoint.save_final_model and ckpt_path: save_final_model_ddp( model=model, @@ -199,7 +218,6 @@ def main(args: DictConfig) -> float | None: dist_config=dist_config, ) - # Clean up distributed training perf_logger.finish() torch.distributed.destroy_process_group() diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py index be3e80a514..558d27366d 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py @@ -13,6 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Fully Sharded Data Parallel v2 (FSDP2) training script for Llama 3 with TransformerEngine. + +Model weights and optimizer states are sharded across GPUs, allowing training of models that exceed +the memory of a single GPU. Supports both TE-accelerated (NVLlamaForCausalLM) and standard +HuggingFace (LlamaForCausalLM) models. + +For very long sequences, use ``train_fsdp2_cp.py`` which adds Context Parallelism on top of FSDP2. +""" + import gc import logging from contextlib import nullcontext @@ -57,7 +66,7 @@ def main(args: DictConfig) -> float | None: Returns: float: The loss value for the final batch. """ - # Initialize the distributed configuration, including creating the distributed process group. 
+ # --- Distributed Setup --- dist_config = DistributedConfig() logger.info("Initializing distributed training: %s", dist_config) device = torch.device(f"cuda:{dist_config.local_rank}") @@ -68,10 +77,9 @@ def main(args: DictConfig) -> float | None: if args.fp8_stats_config.enabled: initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled) - # Create a device mesh for FSDP. device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("dp",)) - # Create an FP8 recipe -- this is only used if FP8 is enabled in the config. + # --- Model Configuration --- fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs ) @@ -83,12 +91,12 @@ def main(args: DictConfig) -> float | None: config_class = LlamaConfig model_class = LlamaForCausalLM - # Create an empty Llama3 model with a causal language model head, e.g. "meta-llama/Meta-Llama-3-8B". + # --- Model Initialization --- config = config_class.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs) # Optionally use transformer engine to initialize only fp8 versions of weights by setting - # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 and fp8 - # versions of weights are kept. + # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 + # and fp8 versions of weights are kept. with ( torch.device("meta") if args.use_meta_device else nullcontext(), transformer_engine.pytorch.quantized_model_init( @@ -99,7 +107,7 @@ def main(args: DictConfig) -> float | None: logger.info("Initialized Model:\n%s", model) - # Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers. 
+ # --- Distributed Wrapping (FSDP2) --- # Each decoder layer should be individually sharded before sharding the full model. for layer in model.model.layers: fully_shard(layer, mesh=device_mesh["dp"]) @@ -118,23 +126,25 @@ def main(args: DictConfig) -> float | None: if args.fp8_stats_config.enabled: debug_api.infer_and_assign_layer_names(model) - # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). + # --- Optimizer & Scheduler --- + # Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) + if args.use_torch_compile: + # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency. + model = torch.compile(model) + + # --- Data Loading --- if args.use_sequence_packing: train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset) else: train_dataloader, dataset_or_sampler = create_bshd_dataloader(dist_config, **args.dataset) - if args.use_torch_compile: - # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency. - model = torch.compile(model) - - # If we're resuming from a checkpoint, load it and set the start step. Otherwise, start from step 0. 
+ # --- Checkpoint Resume --- ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None if args.checkpoint.resume_from_checkpoint and ckpt_path: - logger.info(f"Attempting to load checkpoint from {ckpt_path}") + logger.info("Attempting to load checkpoint from %s", ckpt_path) model, optimizer, scheduler, train_dataloader, start_step, epoch = load_checkpoint_fsdp2( model=model, optimizer=optimizer, @@ -144,7 +154,7 @@ def main(args: DictConfig) -> float | None: dataloader=train_dataloader, process_group=device_mesh.get_group("dp"), ) - logger.info(f"Checkpoint loaded, resuming from step {start_step}, epoch {epoch}") + logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch) else: logger.info("No checkpoint to load, starting from scratch") start_step = 0 @@ -155,8 +165,8 @@ def main(args: DictConfig) -> float | None: gc.collect() torch.cuda.empty_cache() - # Training loop - logger.info(f"Starting training loop from step {start_step} to {args.num_train_steps}") + # --- Training Loop --- + logger.info("Starting training loop from step %s to %s", start_step, args.num_train_steps) step = start_step micro_step = 0 # Gradient accumulation step counter while step < args.num_train_steps: @@ -176,7 +186,7 @@ def main(args: DictConfig) -> float | None: # Log microbatch step data for accumulation metrics perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs) - # Gradient accumulation - only step optimizer after accumulating gradients + # The end of a "full" step (i.e. after possibly multiple gradient accumulation steps). if micro_step % args.grad_acc_steps == 0: micro_step = 0 @@ -217,7 +227,7 @@ def main(args: DictConfig) -> float | None: epoch += 1 dataset_or_sampler.set_epoch(epoch) - # Save final model to a .safetensors file. 
+ # --- Cleanup --- if args.checkpoint.save_final_model and ckpt_path: save_final_model_fsdp2( model=model, @@ -229,7 +239,6 @@ def main(args: DictConfig) -> float | None: if args.checkpoint.async_save and "fsdp2" in _ckpt_futures and _ckpt_futures["fsdp2"] is not None: _ckpt_futures["fsdp2"].result() - # Clean up distributed training perf_logger.finish() torch.distributed.destroy_process_group() diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py index 742e21a63d..9ad3d0e297 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py @@ -13,6 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""FSDP2 with Context Parallelism training script for Llama 3 with TransformerEngine. + +Combines Fully Sharded Data Parallel v2 with Context Parallelism (CP), where each sequence is +split across multiple GPUs along the sequence dimension. This is useful for training with very long +sequences that do not fit into a single GPU's memory even with FSDP2 alone. Only supports +TE-accelerated models (NVLlamaForCausalLM). + +For standard FSDP2 training without context parallelism, use ``train_fsdp2.py`` instead. 
+""" + import gc import logging from contextlib import nullcontext @@ -28,7 +38,13 @@ from torch.optim import AdamW from transformer_engine.common.recipe import Format -from checkpoint import load_checkpoint_fsdp2, save_checkpoint_fsdp2, save_final_model_fsdp2, should_save_checkpoint +from checkpoint import ( + _ckpt_futures, + load_checkpoint_fsdp2, + save_checkpoint_fsdp2, + save_final_model_fsdp2, + should_save_checkpoint, +) from collator import ContextParallelDataLoaderWrapper, DataCollatorForContextParallel from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig @@ -43,37 +59,36 @@ @hydra.main(config_path="hydra_config", config_name="L0_sanity_cp", version_base="1.2") def main(args: DictConfig) -> float | None: - """Train Llama3 with TE layers using FSDP2. + """Train Llama3 with TE layers using FSDP2 with Context Parallelism. Returns: float: The loss value for the final batch. """ - # Initialize the distributed configuration, including creating the distributed process group. + # --- Distributed Setup --- dist_config = DistributedConfig() logger.info("Initializing distributed training: %s", dist_config) device = torch.device(f"cuda:{dist_config.local_rank}") torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl", device_id=device) torch.cuda.set_device(dist_config.local_rank) - # Create a device mesh for FSDP. device_mesh = init_device_mesh( "cuda", mesh_shape=(dist_config.world_size // args.cp_size, args.cp_size), mesh_dim_names=("dp", "cp"), ) - logger.info(f"Created device mesh: {device_mesh}") + logger.info("Created device mesh: %s", device_mesh) - # Create an FP8 recipe -- this is only used if FP8 is enabled in the config. + # --- Model Configuration --- fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs ) - # Create an empty Llama3 model with a causal language model head, e.g. 
"meta-llama/Meta-Llama-3-8B". + # --- Model Initialization --- config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs) # Optionally use transformer engine to initialize only fp8 versions of weights by setting - # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 and fp8 - # versions of weights are kept. + # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 + # and fp8 versions of weights are kept. with ( torch.device("meta") if args.use_meta_device else nullcontext(), transformer_engine.pytorch.quantized_model_init( @@ -84,7 +99,7 @@ def main(args: DictConfig) -> float | None: logger.info("Initialized Model:\n%s", model) - # Create a flattened mesh for FSDP2 sharding. This will shard the model across both the DP and CP ranks. + # --- Distributed Wrapping (FSDP2 + CP) --- cp_dp_mesh = device_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_shard_cp") # Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers. @@ -105,7 +120,8 @@ def main(args: DictConfig) -> float | None: # TE layers require special handling to initialize the weights from the meta device. model.init_empty_weights() - # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). + # --- Optimizer & Scheduler --- + # Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) @@ -113,11 +129,16 @@ def main(args: DictConfig) -> float | None: # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency. model = torch.compile(model) - # Create the context-aware dataloader. 
We only create the dataloader on rank 0 and wrap it in a - # ContextParallelDataLoaderWrapper that will shard and distribute the data across the context parallelism group. + # --- Data Loading --- + # Create the context-aware dataloader. if args.dataset.get("pad_sequences_to_be_divisible_by", None) is None: + # The dual chunk algorithm gives each CP rank 2 chunks from each sequence, so we need each sequence to be + # divisible by cp_mesh.size() * 2. logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2") OmegaConf.update(args, "dataset.pad_sequences_to_be_divisible_by", device_mesh["cp"].size() * 2) + + # We only create the dataloader on rank 0, which is responsible for loading data for all CP (and eventually TP) + # ranks. This ensures that the data remains synchronized, even if we're using a non-deterministic data pipeline. if device_mesh["cp"].get_local_rank() == 0: if args.use_sequence_packing: train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset) @@ -135,12 +156,13 @@ def main(args: DictConfig) -> float | None: train_dataloader = None dataset_or_sampler = None + # On all ranks, we create a ContextParallelDataLoaderWrapper that broadcasts the data from cp rank 0. train_dataloader = ContextParallelDataLoaderWrapper(train_dataloader, device_mesh["cp"]) - # If we're resuming from a checkpoint, load it and set the start step. Otherwise, start from step 0. 
+ # --- Checkpoint Resume --- ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None if args.checkpoint.resume_from_checkpoint and ckpt_path: - logger.info(f"Attempting to load checkpoint from {ckpt_path}") + logger.info("Attempting to load checkpoint from %s", ckpt_path) model, optimizer, scheduler, train_dataloader, start_step, epoch = load_checkpoint_fsdp2( model=model, optimizer=optimizer, @@ -150,7 +172,7 @@ def main(args: DictConfig) -> float | None: dataloader=train_dataloader, process_group=cp_dp_mesh.get_group(), ) - logger.info(f"Checkpoint loaded, resuming from step {start_step}, epoch {epoch}") + logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch) else: logger.info("No checkpoint to load, starting from scratch") start_step = 0 @@ -161,8 +183,8 @@ def main(args: DictConfig) -> float | None: gc.collect() torch.cuda.empty_cache() - # Training loop - logger.info(f"Starting training loop from step {start_step} to {args.num_train_steps}") + # --- Training Loop --- + logger.info("Starting training loop from step %s to %s", start_step, args.num_train_steps) step = start_step micro_step = 0 # Gradient accumulation step counter while step < args.num_train_steps: @@ -185,7 +207,7 @@ def main(args: DictConfig) -> float | None: # Log microbatch step data for accumulation metrics perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs) - # Gradient accumulation - only step optimizer after accumulating gradients + # The end of a "full" step (i.e. after possibly multiple gradient accumulation steps). if micro_step % args.grad_acc_steps == 0: micro_step = 0 @@ -227,7 +249,7 @@ def main(args: DictConfig) -> float | None: if dataset_or_sampler is not None: # The dataset only exists on rank 0 dataset_or_sampler.set_epoch(epoch) - # Save final model to a .safetensors file. 
+ # --- Cleanup --- if args.checkpoint.save_final_model and ckpt_path: save_final_model_fsdp2( model=model, @@ -235,7 +257,10 @@ def main(args: DictConfig) -> float | None: dist_config=dist_config, ) - # Clean up distributed training + # Make sure we don't have any outstanding checkpoint save futures. + if args.checkpoint.async_save and "fsdp2" in _ckpt_futures and _ckpt_futures["fsdp2"] is not None: + _ckpt_futures["fsdp2"].result() + perf_logger.finish() torch.distributed.destroy_process_group() diff --git a/bionemo-recipes/recipes/llama3_native_te/train_mfsdp_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_mfsdp_cp.py new file mode 100644 index 0000000000..742f7cff01 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/train_mfsdp_cp.py @@ -0,0 +1,281 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Megatron-FSDP with Context Parallelism training script for Llama 3 with TransformerEngine. + +Combines Megatron-FSDP with Context Parallelism (CP), where each sequence is split across multiple +GPUs along the sequence dimension. This is useful for training with very long sequences that do not +fit into a single GPU's memory even with FSDP alone. Only supports TE-accelerated models +(NVLlamaForCausalLM). + +For standard FSDP2 training without context parallelism, use ``train_fsdp2.py`` instead. 
+For FSDP2 with context parallelism, use ``train_fsdp2_cp.py`` instead. +""" + +import gc +import logging +from pathlib import Path + +import hydra +import nvtx +import torch +import transformer_engine.pytorch +from megatron_fsdp.fully_shard import fully_shard as mfsdp_fully_shard +from omegaconf import DictConfig, OmegaConf +from torch.distributed.device_mesh import init_device_mesh +from torch.optim import AdamW +from transformer_engine.common.recipe import Format + +from checkpoint import ( + load_checkpoint_mfsdp, + save_checkpoint_mfsdp, + save_final_model_mfsdp, + should_save_checkpoint, +) +from collator import ContextParallelDataLoaderWrapper, DataCollatorForContextParallel +from dataset import create_bshd_dataloader, create_thd_dataloader +from distributed_config import DistributedConfig +from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM +from perf_logger import PerfLogger +from scheduler import get_cosine_annealing_schedule_with_warmup + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +@hydra.main(config_path="hydra_config", config_name="L0_sanity_cp", version_base="1.2") +def main(args: DictConfig) -> float | None: + """Train Llama3 with TE layers using Megatron-FSDP with Context Parallelism. + + Returns: + float: The loss value for the final batch. + """ + # --- Distributed Setup --- + dist_config = DistributedConfig() + logger.info("Initializing distributed training: %s", dist_config) + device = torch.device(f"cuda:{dist_config.local_rank}") + torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl", device_id=device) + torch.cuda.set_device(dist_config.local_rank) + + # Create a 3D device mesh (dp, cp, tp) where tp is a dummy dimension of size 1 required by mfsdp. 
+ dp_size = dist_config.world_size // args.cp_size + device_mesh = init_device_mesh( + "cuda", + mesh_shape=(dp_size, args.cp_size, 1), + mesh_dim_names=("dp", "cp", "tp"), + ) + logger.info("Created device mesh: %s", device_mesh) + + # --- Model Configuration --- + fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( + fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs + ) + + # --- Model Initialization --- + config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs) + + # mfsdp does not support tied weight parameters. If tie_word_embeddings is enabled, we need to untie them so that + # lm_head.weight and embed_tokens.weight are separate parameters for the mfsdp optimizer buffer. + if config.tie_word_embeddings: + logger.warning( + "Megatron-FSDP does not support tied weight parameters. Setting tie_word_embeddings=False. " + "This means lm_head.weight will be a separate parameter from embed_tokens.weight." + ) + config.tie_word_embeddings = False + + # Optionally use transformer engine to initialize only fp8 versions of weights by setting + # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 + # and fp8 versions of weights are kept. + # NOTE: Meta device initialization for mfsdp is handled by the `init_model_with_meta_device` kwarg in + # fully_shard_kwargs, so we do NOT use `torch.device("meta")` here (unlike train_fsdp2_cp.py). + with transformer_engine.pytorch.quantized_model_init( + recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs + ): + model = NVLlamaForCausalLM(config) + + logger.info("Initialized Model:\n%s", model) + + # --- Optimizer (created before mfsdp wrapping, will be wrapped by fully_shard) --- + # Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). 
+ optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore + + # --- Distributed Wrapping (Megatron-FSDP + CP) --- + model, optimizer = mfsdp_fully_shard( + module=model, + optimizer=optimizer, + fsdp_unit_modules=[ + transformer_engine.pytorch.TransformerLayer, + transformer_engine.pytorch.LayerNorm, + transformer_engine.pytorch.LayerNormLinear, + ], + device_mesh=device_mesh, + dp_shard_dim="dp", + tp_dim="tp", + **args.fully_shard_kwargs, + ) + + # Attach the CP group to each transformer layer. + for layer in model.module.model.layers: + layer.set_context_parallel_group( + device_mesh["cp"].get_group(), + torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()), + torch.cuda.Stream(), + ) + + # --- Scheduler (must be created after mfsdp wrapping since fully_shard modifies the optimizer) --- + scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) + + if args.use_torch_compile: + logger.warning( + "BIONEMO-2977: Using torch.compile with mfsdp is currently not supported. `use_torch_compile` was set to " + "true, but will be ignored." + ) + + # --- Data Loading --- + # Create the context-aware dataloader. + if args.dataset.get("pad_sequences_to_be_divisible_by", None) is None: + # The dual chunk algorithm gives each CP rank 2 chunks from each sequence, so we need each sequence to be + # divisible by cp_mesh.size() * 2. + logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2") + OmegaConf.update(args, "dataset.pad_sequences_to_be_divisible_by", device_mesh["cp"].size() * 2) + + # We only create the dataloader on rank 0, which is responsible for loading data for all CP (and eventually TP) + # ranks. This ensures that the data remains synchronized, even if we're using a non-deterministic data pipeline. 
+ if device_mesh["cp"].get_local_rank() == 0: + if args.use_sequence_packing: + train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset) + else: + train_dataloader, dataset_or_sampler = create_bshd_dataloader(dist_config, **args.dataset) + + train_dataloader.collate_fn = DataCollatorForContextParallel( + collator=train_dataloader.collate_fn, + device_mesh=device_mesh, + qkv_format=args.config_kwargs.attn_input_format, + is_causal_lm=True, + ) + + else: + train_dataloader = None + dataset_or_sampler = None + + # On all ranks, we create a ContextParallelDataLoaderWrapper that broadcasts the data from cp rank 0. + train_dataloader = ContextParallelDataLoaderWrapper(train_dataloader, device_mesh["cp"]) + + # --- Checkpoint Resume --- + ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_mfsdp" if args.checkpoint.ckpt_dir else None + if args.checkpoint.resume_from_checkpoint and ckpt_path: + logger.info("Attempting to load checkpoint from %s", ckpt_path) + model, optimizer, scheduler, train_dataloader, start_step, epoch = load_checkpoint_mfsdp( + model=model, + optimizer=optimizer, + scheduler=scheduler, + ckpt_path=ckpt_path, + dist_config=dist_config, + dataloader=train_dataloader, + ) + logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch) + else: + logger.info("No checkpoint to load, starting from scratch") + start_step = 0 + epoch = 0 + + perf_logger = PerfLogger(dist_config, args) + + gc.collect() + torch.cuda.empty_cache() + + # --- Training Loop --- + logger.info("Starting training loop from step %s to %s", start_step, args.num_train_steps) + step = start_step + micro_step = 0 # Gradient accumulation step counter + while step < args.num_train_steps: + for batch in train_dataloader: + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 + + micro_step += 1 + + # Forward pass with mixed precision. 
+ with nvtx.annotate("Forward pass", color="green"): + with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): + outputs = model(**batch) + + # Backward pass - scale loss by grad_acc_steps for proper gradient averaging + loss = outputs.loss / args.grad_acc_steps + + with nvtx.annotate("Backward pass", color="red"): + loss.backward() + + # Log microbatch step data for accumulation metrics + perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs) + + # The end of a "full" step (i.e. after possibly multiple gradient accumulation steps). + if micro_step % args.grad_acc_steps == 0: + micro_step = 0 + + # Compute and clip gradient norms. + # NOTE: grad clipping with mfsdp has been reported to cause hangs in some configurations. + # If you experience hangs, try commenting out the clip_grad_norm_ call. + total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + # Step optimizer. + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + perf_logger.log_step( + step=step, + grad_norm=total_norm, + lr=optimizer.param_groups[0]["lr"], + ) + + if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): + save_checkpoint_mfsdp( + model=model, + optimizer=optimizer, + scheduler=scheduler, + ckpt_path=ckpt_path, + step=step, + epoch=epoch, + dist_config=dist_config, + dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None, + max_checkpoints=args.checkpoint.max_checkpoints, + ) + + step += 1 + if step >= args.num_train_steps: + break + + # Dataloader exhausted, incrementing epoch + epoch += 1 + if dataset_or_sampler is not None: # The dataset only exists on rank 0 + dataset_or_sampler.set_epoch(epoch) + + # --- Cleanup --- + if args.checkpoint.save_final_model and ckpt_path: + save_final_model_mfsdp( + model=model, + save_directory=ckpt_path / "final_model", + dist_config=dist_config, + ) + + perf_logger.finish() + torch.distributed.destroy_process_group() 
+ + return perf_logger.min_loss + + +if __name__ == "__main__": + main()