diff --git a/bionemo-recipes/models/esm2/src/esm/collator.py b/bionemo-recipes/models/esm2/src/esm/collator.py
index 63b4716410..dcf998c797 100644
--- a/bionemo-recipes/models/esm2/src/esm/collator.py
+++ b/bionemo-recipes/models/esm2/src/esm/collator.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Data collator for THD input format tests.
+"""Data collators for sequence packing and context parallel training.
This should eventually get moved to a separate package, or possibly upstreamed into `transformers`.
"""
@@ -674,6 +674,16 @@ def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tupl
def _pt_flatten_collate(features: list[dict[str, list[int]]], return_position_ids: bool = False):
+ """Flatten a list of tokenized samples into a single packed batch with cumulative sequence lengths.
+
+ Args:
+ features: List of tokenized samples, each containing at least ``input_ids``.
+ return_position_ids: Whether to return position ids for each token.
+
+ Returns:
+ A dictionary with packed ``input_ids``, ``cu_seq_lens_q``/``cu_seq_lens_k``, and
+ ``max_length_q``/``max_length_k``.
+ """
is_labels_provided = "labels" in features[0]
sample_lengths = [len(sample["input_ids"]) for sample in features]
@@ -920,7 +930,7 @@ def process_tensor_bshd(val):
class BatchType(TypedDict):
- """The fields in the batch dictionary fo THD context parallel."""
+ """The fields in the batch dictionary for THD context parallel."""
input_ids: torch.Tensor
labels: torch.Tensor | None
diff --git a/bionemo-recipes/models/llama3/README.md b/bionemo-recipes/models/llama3/README.md
index 758de85efb..a5fa2efbaf 100644
--- a/bionemo-recipes/models/llama3/README.md
+++ b/bionemo-recipes/models/llama3/README.md
@@ -1,6 +1,166 @@
-# 🚧 Llama-3.1 Optimized with NVIDIA TransformerEngine
+# Llama-3.1 Optimized with NVIDIA TransformerEngine
-This folder contains source code and tests for an Llama-3.1 model that inherits from the transformers `PreTrainedModel`
-class and uses TransformerEngine layers.
+This folder contains source code and tests for Llama-3.\* style models that inherit from the transformers
+`PreTrainedModel` class and use TransformerEngine layers. Unlike the ESM-2 model, we do not currently distribute
+pre-converted TE checkpoints on HuggingFace Hub. Instead, users can convert existing Llama 3 checkpoints from
+HuggingFace using the provided conversion utilities.
-This folder is currently work in progress and is not yet ready for general use.
+## Feature support
+
+The Llama-3 implementation natively supports the following TransformerEngine-provided optimizations:
+
+| Feature | Support |
+| --------------------------------------- | -------------------------------------------------------------------------------- |
+| **FP8**                                 | ✅ Supported on compute capability 9.0 and above (Hopper+)                       |
+| **MXFP8**                               | ✅ Supported on compute capability 10.0 and 10.3 (Blackwell), 12.0 support pending |
+| **Sequence Packing / THD input format** | ✅ Supported |
+| **FP8 with THD input format** | ✅ Supported where FP8 is supported |
+| **Import from HuggingFace checkpoints** | ✅ Supported |
+| **Export to HuggingFace checkpoints** | ✅ Supported |
+| **KV-cache inference** | ✅ Supported (including beam search) |
+| **Context Parallelism** | ✅ Supported |
+| **Tensor Parallelism** | 🚧 Under development |
+
+Refer to [BioNeMo Recipes](../../recipes/llama3_native_te/README.md) for more details on how to use these features to accelerate model
+training and inference with native PyTorch training loops.
+
+## Inference Examples
+
+### Quick start: convert and run
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from convert import convert_llama_hf_to_te
+
+# Load the original HuggingFace Llama 3 model
+model_hf = AutoModelForCausalLM.from_pretrained(
+ "meta-llama/Llama-3.2-1B-Instruct", dtype=torch.bfloat16
+)
+
+# Convert to TransformerEngine.
+model_te = convert_llama_hf_to_te(model_hf)
+model_te.to("cuda")
+
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
+tokenizer.pad_token = tokenizer.eos_token
+
+inputs = tokenizer("The quick brown fox", return_tensors="pt")
+inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+with torch.no_grad():
+ output_ids = model_te.generate(**inputs, max_new_tokens=16, use_cache=False)
+
+print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
+```
+
+### Inference with KV-cache
+
+For efficient autoregressive generation, use the TE-provided `InferenceParams` KV-cache:
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformer_engine.pytorch.attention import InferenceParams
+
+from convert import convert_llama_hf_to_te
+
+model_hf = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-3.2-1B-Instruct", dtype=torch.bfloat16
+)
+model_te = convert_llama_hf_to_te(
+ model_hf, attn_input_format="thd", self_attn_mask_type="padding_causal"
+)
+model_te.to("cuda")
+
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
+
+inputs = tokenizer("The quick brown fox", return_tensors="pt")
+inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+# Allocate KV-cache
+past_key_values = InferenceParams(
+ max_batch_size=1,
+ max_sequence_length=256,
+ num_heads_kv=model_te.config.num_key_value_heads,
+ head_dim_k=model_te.config.hidden_size // model_te.config.num_attention_heads,
+ dtype=torch.bfloat16,
+ qkv_format="thd",
+ max_ctx_len=256,
+)
+
+for layer_number in range(1, model_te.config.num_hidden_layers + 1):
+ past_key_values.allocate_memory(layer_number)
+
+with torch.no_grad():
+ output_ids = model_te.generate(
+ **inputs,
+ max_new_tokens=16,
+ use_cache=True,
+ past_key_values=past_key_values,
+ )
+
+print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
+```
+
+## Recipe Links
+
+Training recipes are available in the `bionemo-recipes/recipes/` directory:
+
+- **[llama3_native_te](../../recipes/llama3_native_te/)** - Demonstrates training with a native PyTorch training loop
+ using FSDP2, including FP8, sequence packing, and context parallelism.
+
+## Converting Between Model Formats
+
+This section explains how to convert between Hugging Face Transformers and Transformer Engine (TE) Llama 3 model
+formats. The process demonstrates bidirectional conversion: from Transformers to TE format for optimized training and
+inference, and back to Hugging Face Transformers format for sharing and deployment.
+
+### Converting from HF Transformers to TE
+
+```python
+from transformers import AutoModelForCausalLM
+
+from convert import convert_llama_hf_to_te
+
+model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
+model_te = convert_llama_hf_to_te(model_hf)
+model_te.save_pretrained("/path/to/te_checkpoint")
+```
+
+### Converting from TE back to HF Transformers
+
+```python
+from convert import convert_llama_te_to_hf
+from modeling_llama_te import NVLlamaForCausalLM
+
+model_te = NVLlamaForCausalLM.from_pretrained("/path/to/te_checkpoint")
+model_hf = convert_llama_te_to_hf(model_te)
+model_hf.save_pretrained("/path/to/hf_checkpoint")
+```
+
+Once converted back to HF format, the model can be loaded by any library that supports Llama 3, such as
+[vLLM](https://github.com/vllm-project/vllm) or [SGLang](https://github.com/sgl-project/sglang).
+
+### Validating Converted Models
+
+To validate the converted models, refer to the commands in [Inference Examples](#inference-examples) above to load and
+test both the original and converted models to ensure loss and logit values are similar. Additionally, refer to the
+golden value tests in [test_modeling_llama_te.py](tests/test_modeling_llama_te.py).
+
+## Developer Guide
+
+### Running tests
+
+To run tests locally, run `recipes_local_test.py` from the repository root with the model directory as an argument.
+
+```bash
+./ci/scripts/recipes_local_test.py bionemo-recipes/models/llama3/
+```
+
+### Development container
+
+To use the provided devcontainer, use "Dev Containers: Reopen in Container" from the VSCode menu, and choose the
+"BioNeMo Recipes Dev Container" option. To run the tests inside the container, first install the model package in
+editable mode with `pip install -e .`, then run `pytest -v .` in the model directory.
diff --git a/bionemo-recipes/models/llama3/collator.py b/bionemo-recipes/models/llama3/collator.py
index 63b4716410..dcf998c797 100644
--- a/bionemo-recipes/models/llama3/collator.py
+++ b/bionemo-recipes/models/llama3/collator.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Data collator for THD input format tests.
+"""Data collators for sequence packing and context parallel training.
This should eventually get moved to a separate package, or possibly upstreamed into `transformers`.
"""
@@ -674,6 +674,16 @@ def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tupl
def _pt_flatten_collate(features: list[dict[str, list[int]]], return_position_ids: bool = False):
+ """Flatten a list of tokenized samples into a single packed batch with cumulative sequence lengths.
+
+ Args:
+ features: List of tokenized samples, each containing at least ``input_ids``.
+ return_position_ids: Whether to return position ids for each token.
+
+ Returns:
+ A dictionary with packed ``input_ids``, ``cu_seq_lens_q``/``cu_seq_lens_k``, and
+ ``max_length_q``/``max_length_k``.
+ """
is_labels_provided = "labels" in features[0]
sample_lengths = [len(sample["input_ids"]) for sample in features]
@@ -920,7 +930,7 @@ def process_tensor_bshd(val):
class BatchType(TypedDict):
- """The fields in the batch dictionary fo THD context parallel."""
+ """The fields in the batch dictionary for THD context parallel."""
input_ids: torch.Tensor
labels: torch.Tensor | None
diff --git a/bionemo-recipes/models/llama3/modeling_llama_te.py b/bionemo-recipes/models/llama3/modeling_llama_te.py
index d4821be1f7..433edd084a 100644
--- a/bionemo-recipes/models/llama3/modeling_llama_te.py
+++ b/bionemo-recipes/models/llama3/modeling_llama_te.py
@@ -62,7 +62,7 @@ def init_empty_weights(self):
if hasattr(module, "reset_parameters"):
module.reset_parameters()
- # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use
+ # The embed_tokens layer is the only non-TE layer in this model we need to deal with. We use
# `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard
# deviation.
self.model.embed_tokens.to_empty(device="cuda")
@@ -363,9 +363,10 @@ def forward(
)
-class NVLlamaForSequenceClassification( # noqa: D101
+class NVLlamaForSequenceClassification(
transformers.modeling_layers.GenericForSequenceClassification, NVLlamaPreTrainedModel
-): ...
+):
+ """Llama3 model with sequence classification head."""
class NVLlamaForQuestionAnswering(transformers.modeling_layers.GenericForQuestionAnswering, NVLlamaPreTrainedModel):
@@ -374,9 +375,10 @@ class NVLlamaForQuestionAnswering(transformers.modeling_layers.GenericForQuestio
base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
-class NVLlamaForTokenClassification( # noqa: D101
+class NVLlamaForTokenClassification(
transformers.modeling_layers.GenericForTokenClassification, NVLlamaPreTrainedModel
-): ...
+):
+ """Llama3 model with token classification head."""
torch._dynamo.config.capture_scalar_outputs = True
diff --git a/bionemo-recipes/recipes/esm2_native_te/collator.py b/bionemo-recipes/recipes/esm2_native_te/collator.py
index 63b4716410..dcf998c797 100644
--- a/bionemo-recipes/recipes/esm2_native_te/collator.py
+++ b/bionemo-recipes/recipes/esm2_native_te/collator.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Data collator for THD input format tests.
+"""Data collators for sequence packing and context parallel training.
This should eventually get moved to a separate package, or possibly upstreamed into `transformers`.
"""
@@ -674,6 +674,16 @@ def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tupl
def _pt_flatten_collate(features: list[dict[str, list[int]]], return_position_ids: bool = False):
+ """Flatten a list of tokenized samples into a single packed batch with cumulative sequence lengths.
+
+ Args:
+ features: List of tokenized samples, each containing at least ``input_ids``.
+ return_position_ids: Whether to return position ids for each token.
+
+ Returns:
+ A dictionary with packed ``input_ids``, ``cu_seq_lens_q``/``cu_seq_lens_k``, and
+ ``max_length_q``/``max_length_k``.
+ """
is_labels_provided = "labels" in features[0]
sample_lengths = [len(sample["input_ids"]) for sample in features]
@@ -920,7 +930,7 @@ def process_tensor_bshd(val):
class BatchType(TypedDict):
- """The fields in the batch dictionary fo THD context parallel."""
+ """The fields in the batch dictionary for THD context parallel."""
input_ids: torch.Tensor
labels: torch.Tensor | None
diff --git a/bionemo-recipes/recipes/esm2_native_te/fp8_debugging.py b/bionemo-recipes/recipes/esm2_native_te/fp8_debugging.py
index 6cd452ebb8..d01024f04c 100644
--- a/bionemo-recipes/recipes/esm2_native_te/fp8_debugging.py
+++ b/bionemo-recipes/recipes/esm2_native_te/fp8_debugging.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/bionemo-recipes/recipes/esm2_peft_te/collator.py b/bionemo-recipes/recipes/esm2_peft_te/collator.py
index 63b4716410..dcf998c797 100644
--- a/bionemo-recipes/recipes/esm2_peft_te/collator.py
+++ b/bionemo-recipes/recipes/esm2_peft_te/collator.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Data collator for THD input format tests.
+"""Data collators for sequence packing and context parallel training.
This should eventually get moved to a separate package, or possibly upstreamed into `transformers`.
"""
@@ -674,6 +674,16 @@ def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tupl
def _pt_flatten_collate(features: list[dict[str, list[int]]], return_position_ids: bool = False):
+ """Flatten a list of tokenized samples into a single packed batch with cumulative sequence lengths.
+
+ Args:
+ features: List of tokenized samples, each containing at least ``input_ids``.
+ return_position_ids: Whether to return position ids for each token.
+
+ Returns:
+ A dictionary with packed ``input_ids``, ``cu_seq_lens_q``/``cu_seq_lens_k``, and
+ ``max_length_q``/``max_length_k``.
+ """
is_labels_provided = "labels" in features[0]
sample_lengths = [len(sample["input_ids"]) for sample in features]
@@ -920,7 +930,7 @@ def process_tensor_bshd(val):
class BatchType(TypedDict):
- """The fields in the batch dictionary fo THD context parallel."""
+ """The fields in the batch dictionary for THD context parallel."""
input_ids: torch.Tensor
labels: torch.Tensor | None
diff --git a/bionemo-recipes/recipes/llama3_native_te/README.md b/bionemo-recipes/recipes/llama3_native_te/README.md
index b7416929cc..aa776026d0 100644
--- a/bionemo-recipes/recipes/llama3_native_te/README.md
+++ b/bionemo-recipes/recipes/llama3_native_te/README.md
@@ -16,9 +16,9 @@ bionemo-framework repository. You can download a zipped directory of this folder
## Supported Models and Training Features
-| Model | BF16 | FP8[1] | THD Input Format | FP8 with THD Input Format | MXFP8[2] | Context Parallelism |
-| ---------------------------------------- | ---- | ----------------- | ---------------- | ------------------------- | ------------------- | ------------------- |
-| [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | 🚧 | 🚧 | 🚧 |
+| Model | BF16 | FP8[1] | THD Input Format | FP8 with THD Input Format | MXFP8[2] | Context Parallelism | Tensor Parallelism |
+| ---------------------------------------- | ---- | ----------------- | ---------------- | ------------------------- | ------------------- | ------------------- | ------------------ |
+| [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 |
✅: Supported
🚧: Under development
@@ -42,7 +42,8 @@ To run the container, run:
docker run -it --gpus all --network host --ipc=host --rm -v ${PWD}:/workspace/bionemo llama3_native_te /bin/bash
```
-Alternatively, the dependencies can be installed manually in an environment with CUDA support. See `requirements.txt` for the list of dependencies.
+Alternatively, the dependencies can be installed manually in an environment with CUDA support. See `requirements.txt`
+for the list of dependencies.
### Performance Benchmarks
@@ -70,10 +71,12 @@ Training was performed with BF16 precision.
### Distributed Training
-This recipe supports distributed training using DDP and FSDP2, shown in two separate training entrypoints:
+This recipe supports distributed training using DDP, FSDP2, FSDP2 with Context Parallelism, and Megatron-FSDP with Context Parallelism, shown in four separate training entrypoints:
- [Distributed Data Parallel (DDP)](https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html), shown in `train_ddp.py`
- [Fully Sharded Data Parallel 2 (FSDP2)](https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html), shown in `train_fsdp2.py`
+- FSDP2 with Context Parallelism, shown in `train_fsdp2_cp.py`
+- Megatron-FSDP with Context Parallelism, shown in `train_mfsdp_cp.py`
## Commands to Launch Training
@@ -91,6 +94,21 @@ torchrun --nproc_per_node=2 train_fsdp2.py # or train_ddp.py
Multi-Node training is supported with both strategies.
+A convergence test configuration (`L0_convergence`) is also available, which uses a tiny Llama model
+to verify that the training loop can overfit on a small dataset:
+
+```bash
+python train_fsdp2.py --config-name L0_convergence
+```
+
+Gradient accumulation is supported with both strategies. To enable gradient accumulation, set `grad_acc_steps` to the
+number of steps to accumulate gradients before updating the model parameters. This is useful to scale the effective
+batch size while running on a smaller number of GPUs.
+
+```bash
+python train_fsdp2.py --config-name L0_sanity grad_acc_steps=2
+```
+
### FP8 Training
To run training with FP8, enable it by overriding the `fp8_config.enabled=true` configuration parameter. Additional FP8
@@ -106,15 +124,15 @@ We also provide a mechanism to receive tensor data related to FP8 layers during
To enable this please select the following config options.
-```python
+```bash
python train_fsdp2.py \
-fp8_stats_config.enabled=True # whether to log stats or not
-fp8_stats_config.fp8_log_dir=./logs/fp8_stats_logs_dummy # where to store the logs
-fp8_stats_config.fp8_stats_file=./fp8_debugging_stats.yaml # specifies what stats you want to run. Currently this is saved in this yaml file.
-fp8_config.enabled=True # set this to use FP8 otherwise stats logging wont work
+ fp8_stats_config.enabled=True \
+ fp8_stats_config.fp8_log_dir=./logs/fp8_stats_logs_dummy \
+ fp8_stats_config.fp8_stats_file=./fp8_debugging_stats.yaml \
+ fp8_config.enabled=True
```
-Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. It is not yet available for `train_mfsdp`.
+Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts.
The config file structure [fp8_debugging_stats.yaml](fp8_debugging_stats.yaml) is explained in the [NVIDIA Transformer Engine config file documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/debug/2_config_file_structure.html) in more detail. Below we will cover some very basic elements of the file structure.
@@ -142,6 +160,30 @@ python train_fsdp2.py --config-name L0_sanity \
use_sequence_packing=true
```
+### Context Parallel Training
+
+Context parallelism splits each sequence across multiple GPUs along the sequence dimension, enabling training with very
+long sequences. Use `train_fsdp2_cp.py` with the `L0_sanity_cp` configuration and set `cp_size` to the number of context
+parallelism ranks. Works with both BSHD (padded) and THD (packed, no padding) input formats. Only TE models are supported.
+
+```bash
+torchrun --nproc_per_node=4 train_fsdp2_cp.py --config-name L0_sanity_cp cp_size=2
+```
+
+### Megatron-FSDP with Context Parallelism
+
+Megatron-FSDP (`train_mfsdp_cp.py`) provides an alternative FSDP implementation from the `megatron-fsdp` package with
+context parallelism support. It creates a 3D device mesh `(dp, cp, tp)` where `tp` is a dummy dimension of size 1. Only
+TE models are supported. Note that `torch.compile` is not supported with Megatron-FSDP.
+
+```bash
+# Single GPU (cp_size=1)
+python train_mfsdp_cp.py --config-name L0_sanity_cp cp_size=1
+
+# Multi-GPU with context parallelism
+torchrun --nproc_per_node=4 train_mfsdp_cp.py --config-name L0_sanity_cp cp_size=2
+```
+
## Downloading Pre-Training Data For Offline Training
This recipe is configured to use genomic sequences. The default configuration uses a local test file
diff --git a/bionemo-recipes/recipes/llama3_native_te/checkpoint.py b/bionemo-recipes/recipes/llama3_native_te/checkpoint.py
index a5af170661..7d3263f2f2 100644
--- a/bionemo-recipes/recipes/llama3_native_te/checkpoint.py
+++ b/bionemo-recipes/recipes/llama3_native_te/checkpoint.py
@@ -40,6 +40,10 @@
logger = logging.getLogger(__name__)
+
+# Tracks in-flight async checkpoint futures keyed by strategy name (e.g. "fsdp2").
+# Each entry holds the Future returned by dcp_async_save so we can await it before starting
+# the next async save or before shutting down.
_ckpt_futures: dict = {}
@@ -379,6 +383,147 @@ def save_final_model_fsdp2(
logger.info(f"Saved final FSDP2 model to {save_directory} (weights + config only)")
+# ============================================================================
+# mFSDP Checkpointing
+# ============================================================================
+
+
+def load_checkpoint_mfsdp(
+ model: torch.nn.Module,
+ optimizer: torch.optim.Optimizer,
+ scheduler: torch.optim.lr_scheduler.LRScheduler,
+ ckpt_path: str | os.PathLike,
+ dist_config: DistributedConfig,
+ dataloader: StatefulDataLoader | None = None,
+) -> CheckpointOutput:
+ """Load mFSDP distributed checkpoint.
+
+ Args:
+ model: The model to load.
+ optimizer: The optimizer to load.
+ scheduler: The LR scheduler to load.
+ ckpt_path: The directory containing checkpoints.
+ dist_config: The distributed configuration.
+ dataloader: The dataloader to load.
+
+ Returns:
+ Tuple of (model, optimizer, scheduler, dataloader, step, epoch).
+ """
+ checkpoint_path, step = get_latest_checkpoint(ckpt_path)
+ if not checkpoint_path:
+ logger.info("No mFSDP checkpoint found, starting from scratch")
+ return CheckpointOutput(model, optimizer, scheduler, dataloader, 0, 0)
+
+ ckpt_state_dict = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "scheduler": scheduler.state_dict(),
+ "metadata": {
+ "step": step, # Initialize with current step from filename
+ "epoch": 0, # Initialize with default epoch
+ },
+ }
+ torch.distributed.checkpoint.load(state_dict=ckpt_state_dict, checkpoint_id=checkpoint_path)
+
+ model.load_state_dict(ckpt_state_dict["model"], strict=False)
+ optimizer.load_state_dict(ckpt_state_dict["optimizer"])
+ scheduler.load_state_dict(ckpt_state_dict["scheduler"])
+ dataloader = load_dataloader(dataloader, checkpoint_path, dist_config)
+
+ step = ckpt_state_dict["metadata"]["step"]
+ epoch = ckpt_state_dict["metadata"]["epoch"]
+
+ # Ensure all ranks have completed loading before proceeding
+ torch.distributed.barrier()
+
+ logger.info(f"Loaded mFSDP checkpoint from step {step}")
+
+ # Increment the step by one to avoid re-running the previous step.
+ return CheckpointOutput(model, optimizer, scheduler, dataloader, step + 1, epoch)
+
+
+def save_checkpoint_mfsdp(
+ model: torch.nn.Module,
+ optimizer: torch.optim.Optimizer,
+ scheduler: torch.optim.lr_scheduler.LRScheduler,
+ ckpt_path: str | os.PathLike,
+ step: int,
+ epoch: int,
+ dist_config: DistributedConfig,
+ dataloader: StatefulDataLoader | None = None,
+ max_checkpoints: int | None = None,
+) -> None:
+ """Save mFSDP distributed checkpoint.
+
+ Args:
+ model: The model to save.
+ optimizer: The optimizer to save.
+ scheduler: The LR scheduler to save.
+ ckpt_path: The directory to save the checkpoint.
+ step: The step number to save the checkpoint.
+ epoch: The epoch number to save the checkpoint.
+ dist_config: The distributed configuration.
+ dataloader: The dataloader to save.
+ max_checkpoints: The maximum number of checkpoints to keep.
+ """
+ ckpt_path = Path(ckpt_path)
+ checkpoint_path = ckpt_path / f"step_{step}"
+ checkpoint_path.mkdir(parents=True, exist_ok=True)
+
+ # Save dataloader state, if provided.
+ save_dataloader(dataloader, checkpoint_path, dist_config)
+
+ # Save model, optimizer, scheduler state, and metadata
+ state_dict = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "scheduler": scheduler.state_dict(),
+ "metadata": {
+ "step": step,
+ "epoch": epoch,
+ },
+ }
+
+ torch.distributed.checkpoint.save(state_dict, checkpoint_id=checkpoint_path)
+
+ if dist_config.is_main_process():
+ logger.info(f"Saved mFSDP checkpoint to {checkpoint_path}")
+
+ if max_checkpoints is not None and dist_config.is_main_process():
+ prune_checkpoints(ckpt_path, max_checkpoints)
+
+
+def save_final_model_mfsdp(
+ model: torch.nn.Module,
+ save_directory: str | os.PathLike,
+ dist_config: DistributedConfig,
+) -> None:
+ """Save final model for mFSDP - requires parameter gathering on all ranks."""
+ from megatron_fsdp.uneven_dtensor import gather_uneven_dtensor_to_full_tensor
+
+ if dist_config.is_main_process():
+ logger.info("Starting mFSDP parameter gathering...")
+
+ # Parameter gathering must happen on ALL processes
+ unsharded_state_dict = {
+ # Gather all parameters to CPU, and remove the "module." prefix from the Megatron-FSDP class wrapper.
+ k.removeprefix("module."): gather_uneven_dtensor_to_full_tensor(
+ v, target_device=torch.device("cpu")
+ ).to_local()
+ if isinstance(v, torch.distributed.tensor.DTensor)
+ else v
+ for k, v in model.state_dict().items()
+ }
+
+ # Only main process saves the model
+ if not dist_config.is_main_process():
+ return
+
+ os.makedirs(save_directory, exist_ok=True)
+ model.module.save_pretrained(save_directory, state_dict=unsharded_state_dict, safe_serialization=True)
+ logger.info(f"Saved final mFSDP model to {save_directory}")
+
+
# ============================================================================
# Dataloader Checkpointing
# ============================================================================
@@ -439,7 +584,7 @@ def load_dataloader(
)
return dataloader
- dataloader_state = torch.load(dataloader_path)
+ dataloader_state = torch.load(dataloader_path, weights_only=True)
if (
dataloader.num_workers != dataloader_state["num_workers"]
diff --git a/bionemo-recipes/recipes/llama3_native_te/collator.py b/bionemo-recipes/recipes/llama3_native_te/collator.py
index 63b4716410..dcf998c797 100644
--- a/bionemo-recipes/recipes/llama3_native_te/collator.py
+++ b/bionemo-recipes/recipes/llama3_native_te/collator.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Data collator for THD input format tests.
+"""Data collators for sequence packing and context parallel training.
This should eventually get moved to a separate package, or possibly upstreamed into `transformers`.
"""
@@ -674,6 +674,16 @@ def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tupl
def _pt_flatten_collate(features: list[dict[str, list[int]]], return_position_ids: bool = False):
+ """Flatten a list of tokenized samples into a single packed batch with cumulative sequence lengths.
+
+ Args:
+ features: List of tokenized samples, each containing at least ``input_ids``.
+ return_position_ids: Whether to return position ids for each token.
+
+ Returns:
+ A dictionary with packed ``input_ids``, ``cu_seq_lens_q``/``cu_seq_lens_k``, and
+ ``max_length_q``/``max_length_k``.
+ """
is_labels_provided = "labels" in features[0]
sample_lengths = [len(sample["input_ids"]) for sample in features]
@@ -920,7 +930,7 @@ def process_tensor_bshd(val):
class BatchType(TypedDict):
- """The fields in the batch dictionary fo THD context parallel."""
+ """The fields in the batch dictionary for THD context parallel."""
input_ids: torch.Tensor
labels: torch.Tensor | None
diff --git a/bionemo-recipes/recipes/llama3_native_te/dataset.py b/bionemo-recipes/recipes/llama3_native_te/dataset.py
index 6c0b47cf68..3964056d2f 100644
--- a/bionemo-recipes/recipes/llama3_native_te/dataset.py
+++ b/bionemo-recipes/recipes/llama3_native_te/dataset.py
@@ -239,7 +239,6 @@ def create_thd_dataloader(
prefetch_factor: The prefetch factor to use for the dataloader.
max_seq_length: The maximum length of sequences (window size).
stride: The stride for windowing (overlap = stride tokens).
- seed: The seed to use for the distributed sampler and data collator.
buffer_size: The buffer size for shuffle.
use_stateful_dataloader: Whether to use the StatefulDataLoader to enable checkpointing the dataloader state.
text_column: Name of the column containing genomic sequences (default: "text").
diff --git a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py b/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py
index 6cd452ebb8..d01024f04c 100644
--- a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py
+++ b/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/bionemo-recipes/recipes/llama3_native_te/genomic_dataset.py b/bionemo-recipes/recipes/llama3_native_te/genomic_dataset.py
index 12d6bad16f..0e35550c65 100644
--- a/bionemo-recipes/recipes/llama3_native_te/genomic_dataset.py
+++ b/bionemo-recipes/recipes/llama3_native_te/genomic_dataset.py
@@ -17,13 +17,12 @@
Core functions for genomic data preprocessing during training:
- make_upper_case: Convert lowercase tokens to uppercase
-- Evo2MaskingConstants: Standard DNA tokens and control characters
Adapted from NeMo's Evo2 implementation.
"""
from dataclasses import dataclass
-from typing import Any, ClassVar
+from typing import Any
import torch
@@ -47,16 +46,6 @@ def _make_upper_case(tokens, lowercase_start=97, lowercase_end=122, case_diff=32
return uppercase_tensor, lowercase_mask
-class Evo2MaskingConstants:
- """Constants used in Evo2 genomic sequence masking."""
-
- # Standard DNA tokens: A, C, G, T (both uppercase and lowercase)
- DNA_TOKENS: ClassVar[list[int]] = [65, 67, 71, 84, 97, 99, 103, 116]
-
- # Control characters used in data formatting
- CONTROL_TAGS: ClassVar[list[int]] = [64, 35] # '@', '#'
-
-
@dataclass
class GenomicDataCollator:
"""Wrapper collator that adds genomic-specific masking to any base collator.
diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
index d6c181598f..c50909a95b 100644
--- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
+++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
@@ -75,6 +75,19 @@ fp8_stats_config:
fp8_stats_file: ./fp8_debugging_stats.yaml
fp8_log_dir: ./log_fp8_stats
+# mFSDP config
+fully_shard_kwargs:
+ zero_dp_strategy: "optim_grads_params"
+ calculate_per_token_loss: false
+ init_model_with_meta_device: ${use_meta_device}
+ check_for_nan_in_grad: true
+ grad_reduce_in_fp32: false
+ preserve_fp32_weights: true
+ overlap_grad_reduce: true
+ overlap_param_gather: true
+ sync_model_each_microbatch: true
+ average_in_collective: false
+
profiler:
enabled: false
start_step: 10
diff --git a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py
index d4821be1f7..433edd084a 100644
--- a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py
+++ b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py
@@ -62,7 +62,7 @@ def init_empty_weights(self):
if hasattr(module, "reset_parameters"):
module.reset_parameters()
- # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use
+ # The embed_tokens layer is the only non-TE layer in this model we need to deal with. We use
# `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard
# deviation.
self.model.embed_tokens.to_empty(device="cuda")
@@ -363,9 +363,10 @@ def forward(
)
-class NVLlamaForSequenceClassification( # noqa: D101
+class NVLlamaForSequenceClassification(
transformers.modeling_layers.GenericForSequenceClassification, NVLlamaPreTrainedModel
-): ...
+):
+ """Llama3 model with sequence classification head."""
class NVLlamaForQuestionAnswering(transformers.modeling_layers.GenericForQuestionAnswering, NVLlamaPreTrainedModel):
@@ -374,9 +375,10 @@ class NVLlamaForQuestionAnswering(transformers.modeling_layers.GenericForQuestio
base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
-class NVLlamaForTokenClassification( # noqa: D101
+class NVLlamaForTokenClassification(
transformers.modeling_layers.GenericForTokenClassification, NVLlamaPreTrainedModel
-): ...
+):
+ """Llama3 model with token classification head."""
torch._dynamo.config.capture_scalar_outputs = True
diff --git a/bionemo-recipes/recipes/llama3_native_te/requirements.txt b/bionemo-recipes/recipes/llama3_native_te/requirements.txt
index 073d9b39e3..672a3b7e1b 100644
--- a/bionemo-recipes/recipes/llama3_native_te/requirements.txt
+++ b/bionemo-recipes/recipes/llama3_native_te/requirements.txt
@@ -1,5 +1,6 @@
datasets
hydra-core
+megatron-fsdp
torch
torchao!=0.14.0
torchdata
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_genomic_dataset.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_genomic_dataset.py
index 0aa01c260c..f23265e719 100644
--- a/bionemo-recipes/recipes/llama3_native_te/tests/test_genomic_dataset.py
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_genomic_dataset.py
@@ -29,7 +29,7 @@ def tokenizer(tokenizer_path):
return AutoTokenizer.from_pretrained(tokenizer_path)
-# Tests for GenomicDataCollatorForCLM
+# Tests for GenomicDataCollator
def test_collator_basic(tokenizer):
"""Test basic collator functionality."""
base = DataCollatorForLanguageModeling(tokenizer, mlm=False)
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
index 0fb725092a..effffbadf4 100644
--- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
@@ -25,6 +25,7 @@
from train_ddp import main as main_ddp
from train_fsdp2 import main as main_fsdp2
from train_fsdp2_cp import main as main_fsdp2_cp
+from train_mfsdp_cp import main as main_mfsdp_cp
# TODO(@jomitchell): Delete once https://nvbugspro.nvidia.com/bug/5458694 is fixed.
@@ -420,7 +421,7 @@ def test_train_fsdp2_fp8_thd(tmp_path, recipe_path):
@requires_datacenter_hardware
def test_sanity_fsdp2_cp(tmp_path, recipe_path):
- # Run the training script with Hydra configuration overrides
+ """Test FSDP2 with context parallelism training on a single GPU (cp_size=1)."""
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
sanity_config = compose(
config_name="L0_sanity_cp",
@@ -439,6 +440,49 @@ def test_sanity_fsdp2_cp(tmp_path, recipe_path):
assert torch.isfinite(torch.tensor(final_loss)), f"Final loss {final_loss} is not finite"
+def test_sanity_convergence_mfsdp_cp_bshd(tmp_path, recipe_path):
+ """Test Megatron-FSDP with context parallelism training on a single GPU (cp_size=1) using BSHD format."""
+ with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
+ sanity_config = compose(
+ config_name="L0_sanity_cp",
+ overrides=[
+ f"+wandb.dir={tmp_path}",
+ f"checkpoint.ckpt_dir={tmp_path}",
+ "checkpoint.resume_from_checkpoint=false",
+ "config_kwargs.attn_input_format=bshd",
+ "config_kwargs.self_attn_mask_type=causal",
+ ],
+ )
+
+ final_loss = main_mfsdp_cp(sanity_config)
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0"
+
+
+@requires_datacenter_hardware
+def test_sanity_convergence_mfsdp_cp_thd(tmp_path, recipe_path):
+ """Test Megatron-FSDP with context parallelism training on a single GPU (cp_size=1) using THD format."""
+ with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
+ sanity_config = compose(
+ config_name="L0_sanity_cp",
+ overrides=[
+ f"+wandb.dir={tmp_path}",
+ f"checkpoint.ckpt_dir={tmp_path}",
+ "checkpoint.resume_from_checkpoint=false",
+ "config_kwargs.attn_input_format=thd",
+ "config_kwargs.self_attn_mask_type=padding_causal",
+ ],
+ )
+
+ final_loss = main_mfsdp_cp(sanity_config)
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0"
+
+
@requires_fp8
def test_sanity_ddp_fp8_stats_logging(tmp_path, recipe_path):
"""Test that FP8 stats logging creates the expected log files."""
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py
index d64e6e4eed..0d1e4f8e6b 100644
--- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py
@@ -90,12 +90,11 @@ def test_multi_gpu_train_ddp(recipe_path):
"torchrun",
"--standalone",
"--nproc_per_node",
- "2", # 2 processes = 2 GPUs
- "--standalone", # Single node mode
+ "2",
"train_ddp.py",
"--config-name",
"L0_sanity",
- "num_train_steps=4", # Just 4 steps for speed
+ "num_train_steps=4",
],
recipe_path,
)
@@ -118,12 +117,11 @@ def test_multi_gpu_train_fsdp2(recipe_path):
"torchrun",
"--standalone",
"--nproc_per_node",
- "2", # 2 processes = 2 GPUs
- "--standalone", # Single node mode
+ "2",
"train_fsdp2.py",
"--config-name",
"L0_sanity",
- "num_train_steps=4", # Just 4 steps for speed
+ "num_train_steps=4",
],
recipe_path,
)
@@ -144,7 +142,6 @@ def test_multi_gpu_train_ddp_with_checkpointing(tmp_path, recipe_path):
"--standalone",
"--nproc_per_node",
"2",
- "--standalone",
"train_ddp.py",
"--config-name",
"L0_sanity",
@@ -177,7 +174,6 @@ def test_multi_gpu_train_fsdp2_with_checkpointing(tmp_path, recipe_path):
"--standalone",
"--nproc_per_node",
"2",
- "--standalone",
"train_fsdp2.py",
"--config-name",
"L0_sanity",
@@ -197,12 +193,12 @@ def test_multi_gpu_train_fsdp2_with_checkpointing(tmp_path, recipe_path):
@requires_multi_gpu
def test_multi_gpu_train_te_fsdp2_cp_bshd(tmp_path, recipe_path):
+ """Test FSDP2 with context parallelism on 2 GPUs using BSHD input format."""
run_train_cmd(
[
"torchrun",
"--standalone",
"--nproc_per_node=2",
- "--standalone",
"train_fsdp2_cp.py",
"--config-name",
"L0_sanity_cp",
@@ -221,12 +217,12 @@ def test_multi_gpu_train_te_fsdp2_cp_bshd(tmp_path, recipe_path):
@requires_multi_gpu
@requires_datacenter_hardware
def test_multi_gpu_train_te_fsdp2_cp_thd(tmp_path, recipe_path):
+ """Test FSDP2 with context parallelism on 2 GPUs using THD input format."""
run_train_cmd(
[
"torchrun",
"--standalone",
"--nproc_per_node=2",
- "--standalone",
"train_fsdp2_cp.py",
"--config-name",
"L0_sanity_cp",
@@ -242,6 +238,76 @@ def test_multi_gpu_train_te_fsdp2_cp_thd(tmp_path, recipe_path):
)
+@requires_multi_gpu
+def test_multi_gpu_train_mfsdp_cp(tmp_path, recipe_path):
+    """Test the Megatron-FSDP CP training script on 2 GPUs with pure data parallelism (cp_size=1, dp=2)."""
+ run_train_cmd(
+ [
+ "torchrun",
+ "--standalone",
+ "--nproc_per_node=2",
+ "train_mfsdp_cp.py",
+ "--config-name",
+ "L0_sanity_cp",
+ "num_train_steps=10",
+ f"checkpoint.ckpt_dir={tmp_path}",
+ "checkpoint.save_every_n_steps=5",
+ "cp_size=1",
+ "use_sequence_packing=false",
+ "config_kwargs.attn_input_format=bshd",
+ "config_kwargs.self_attn_mask_type=causal",
+ ],
+ recipe_path,
+ )
+
+
+@requires_multi_gpu
+def test_multi_gpu_train_mfsdp_cp_bshd(tmp_path, recipe_path):
+ """Test Megatron-FSDP with context parallelism on 2 GPUs (cp_size=2) using BSHD format."""
+ run_train_cmd(
+ [
+ "torchrun",
+ "--standalone",
+ "--nproc_per_node=2",
+ "train_mfsdp_cp.py",
+ "--config-name",
+ "L0_sanity_cp",
+ "num_train_steps=10",
+ f"checkpoint.ckpt_dir={tmp_path}",
+ "checkpoint.save_every_n_steps=5",
+ "cp_size=2",
+ "use_sequence_packing=false",
+ "config_kwargs.attn_input_format=bshd",
+ "config_kwargs.self_attn_mask_type=causal",
+ ],
+ recipe_path,
+ )
+
+
+@requires_multi_gpu
+@requires_datacenter_hardware
+def test_multi_gpu_train_mfsdp_cp_thd(tmp_path, recipe_path):
+ """Test Megatron-FSDP with context parallelism on 2 GPUs (cp_size=2) using THD format."""
+ run_train_cmd(
+ [
+ "torchrun",
+ "--standalone",
+ "--nproc_per_node=2",
+ "train_mfsdp_cp.py",
+ "--config-name",
+ "L0_sanity_cp",
+ "num_train_steps=10",
+ f"checkpoint.ckpt_dir={tmp_path}",
+ "checkpoint.save_every_n_steps=5",
+ "cp_size=2",
+ "use_sequence_packing=true",
+ "config_kwargs.attn_input_format=thd",
+ "config_kwargs.self_attn_mask_type=padding_causal",
+ ],
+ recipe_path,
+ )
+
+
nsys_available = subprocess.run(["which", "nsys"], check=False, capture_output=True).returncode == 0
diff --git a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py
index 49d22901d3..7aed3ff6f5 100644
--- a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py
+++ b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py
@@ -13,6 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+"""Distributed Data Parallel (DDP) training script for Llama 3 with TransformerEngine.
+
+Each GPU holds a full copy of the model and gradients are synchronized via all-reduce after each
+backward pass. This is the simplest distributed strategy and works well for smaller models that fit
+in a single GPU's memory. Supports both TE-accelerated (NVLlamaForCausalLM) and standard HuggingFace
+(LlamaForCausalLM) models.
+
+For large models that do not fit on a single GPU, use ``train_fsdp2.py`` instead.
+"""
+
+import gc
import logging
from contextlib import nullcontext
from pathlib import Path
@@ -22,7 +33,7 @@
import torch
import transformer_engine
import transformer_engine.pytorch
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
from torch.distributed.device_mesh import init_device_mesh
from torch.optim import AdamW
from transformer_engine.common.recipe import Format
@@ -49,7 +60,7 @@ def main(args: DictConfig) -> float | None:
Returns:
float: The loss value for the final batch.
"""
- # Initialize the distributed configuration, including creating the distributed process group.
+ # --- Distributed Setup ---
dist_config = DistributedConfig()
logger.info("Initializing distributed training: %s", dist_config)
device = torch.device(f"cuda:{dist_config.local_rank}")
@@ -60,11 +71,10 @@ def main(args: DictConfig) -> float | None:
if args.fp8_stats_config.enabled:
initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled)
- # Create a device mesh for DDP. While this isn't strictly necessary, it mirrors the device mesh we create for FSDP2
- # and MFSDP.
+ # Create a device mesh for DDP. While this isn't strictly necessary, it mirrors the device mesh we create for FSDP2.
device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("dp",))
- # Create an FP8 recipe -- this is only used if FP8 is enabled in the config.
+ # --- Model Configuration ---
fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
)
@@ -76,12 +86,12 @@ def main(args: DictConfig) -> float | None:
config_class = LlamaConfig
model_class = LlamaForCausalLM
- # Create an empty Llama3 model with a causal language model head, e.g. "meta-llama/Meta-Llama-3-8B".
+ # --- Model Initialization ---
config = config_class.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
# Optionally use transformer engine to initialize only fp8 versions of weights by setting
- # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 and fp8
- # versions of weights are kept.
+ # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16
+ # and fp8 versions of weights are kept.
with transformer_engine.pytorch.quantized_model_init(
recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs
):
@@ -89,10 +99,7 @@ def main(args: DictConfig) -> float | None:
logger.info("Initialized Model:\n%s", model)
- # Create optimizer.
- optimizer = AdamW(model.parameters(), **args.adamw_kwargs)
- scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)
-
+ # --- Distributed Wrapping (DDP) ---
if args.fp8_stats_config.enabled:
debug_api.infer_and_assign_layer_names(model)
@@ -104,18 +111,25 @@ def main(args: DictConfig) -> float | None:
device_mesh=device_mesh["dp"],
)
- if args.use_sequence_packing:
- train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset)
- else:
- train_dataloader, dataset_or_sampler = create_bshd_dataloader(dist_config, **args.dataset)
+ # --- Optimizer & Scheduler ---
+ # Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873).
+ optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore
+ scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)
if args.use_torch_compile:
# If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
model = torch.compile(model)
- # If we're resuming from a checkpoint, load it and set the start step. Otherwise, start from step 0.
+ # --- Data Loading ---
+ if args.use_sequence_packing:
+ train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset)
+ else:
+ train_dataloader, dataset_or_sampler = create_bshd_dataloader(dist_config, **args.dataset)
+
+ # --- Checkpoint Resume ---
ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_ddp" if args.checkpoint.ckpt_dir else None
if args.checkpoint.resume_from_checkpoint and ckpt_path:
+ logger.info("Attempting to load checkpoint from %s", ckpt_path)
model, optimizer, scheduler, train_dataloader, start_step, epoch = load_checkpoint_ddp(
model=model,
optimizer=optimizer,
@@ -124,22 +138,27 @@ def main(args: DictConfig) -> float | None:
dist_config=dist_config,
dataloader=train_dataloader,
)
+ logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch)
else:
+ logger.info("No checkpoint to load, starting from scratch")
start_step = 0
epoch = 0
perf_logger = PerfLogger(dist_config, args)
- # Training loop
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # --- Training Loop ---
+ logger.info("Starting training loop from step %s to %s", start_step, args.num_train_steps)
step = start_step
micro_step = 0 # Gradient accumulation step counter
while step < args.num_train_steps:
for batch in train_dataloader:
- print(batch["input_ids"].shape)
- batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa PLW2901
+ batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901
micro_step += 1
- # Use no_sync to prevent gradient synchronization until the last microbatch
+ # DDP requires no_sync to skip all-reduce until the last microbatch in the accumulation window.
with model.no_sync() if micro_step % args.grad_acc_steps != 0 else nullcontext():
# Forward pass with mixed precision.
with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe):
@@ -157,7 +176,7 @@ def main(args: DictConfig) -> float | None:
micro_step = 0
# Compute and clip gradient norms.
- total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
+ total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Step optimizer.
optimizer.step()
@@ -191,7 +210,7 @@ def main(args: DictConfig) -> float | None:
epoch += 1
dataset_or_sampler.set_epoch(epoch)
- # Save final model to a .safetensors file.
+ # --- Cleanup ---
if args.checkpoint.save_final_model and ckpt_path:
save_final_model_ddp(
model=model,
@@ -199,7 +218,6 @@ def main(args: DictConfig) -> float | None:
dist_config=dist_config,
)
- # Clean up distributed training
perf_logger.finish()
torch.distributed.destroy_process_group()
diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
index be3e80a514..558d27366d 100644
--- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
+++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
@@ -13,6 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+"""Fully Sharded Data Parallel v2 (FSDP2) training script for Llama 3 with TransformerEngine.
+
+Model weights and optimizer states are sharded across GPUs, allowing training of models that exceed
+the memory of a single GPU. Supports both TE-accelerated (NVLlamaForCausalLM) and standard
+HuggingFace (LlamaForCausalLM) models.
+
+For very long sequences, use ``train_fsdp2_cp.py`` which adds Context Parallelism on top of FSDP2.
+"""
+
import gc
import logging
from contextlib import nullcontext
@@ -57,7 +66,7 @@ def main(args: DictConfig) -> float | None:
Returns:
float: The loss value for the final batch.
"""
- # Initialize the distributed configuration, including creating the distributed process group.
+ # --- Distributed Setup ---
dist_config = DistributedConfig()
logger.info("Initializing distributed training: %s", dist_config)
device = torch.device(f"cuda:{dist_config.local_rank}")
@@ -68,10 +77,9 @@ def main(args: DictConfig) -> float | None:
if args.fp8_stats_config.enabled:
initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled)
- # Create a device mesh for FSDP.
device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("dp",))
- # Create an FP8 recipe -- this is only used if FP8 is enabled in the config.
+ # --- Model Configuration ---
fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
)
@@ -83,12 +91,12 @@ def main(args: DictConfig) -> float | None:
config_class = LlamaConfig
model_class = LlamaForCausalLM
- # Create an empty Llama3 model with a causal language model head, e.g. "meta-llama/Meta-Llama-3-8B".
+ # --- Model Initialization ---
config = config_class.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
# Optionally use transformer engine to initialize only fp8 versions of weights by setting
- # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 and fp8
- # versions of weights are kept.
+ # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16
+ # and fp8 versions of weights are kept.
with (
torch.device("meta") if args.use_meta_device else nullcontext(),
transformer_engine.pytorch.quantized_model_init(
@@ -99,7 +107,7 @@ def main(args: DictConfig) -> float | None:
logger.info("Initialized Model:\n%s", model)
- # Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers.
+ # --- Distributed Wrapping (FSDP2) ---
# Each decoder layer should be individually sharded before sharding the full model.
for layer in model.model.layers:
fully_shard(layer, mesh=device_mesh["dp"])
@@ -118,23 +126,25 @@ def main(args: DictConfig) -> float | None:
if args.fp8_stats_config.enabled:
debug_api.infer_and_assign_layer_names(model)
- # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873).
+ # --- Optimizer & Scheduler ---
+ # Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873).
optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore
scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)
+ if args.use_torch_compile:
+ # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
+ model = torch.compile(model)
+
+ # --- Data Loading ---
if args.use_sequence_packing:
train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset)
else:
train_dataloader, dataset_or_sampler = create_bshd_dataloader(dist_config, **args.dataset)
- if args.use_torch_compile:
- # If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
- model = torch.compile(model)
-
- # If we're resuming from a checkpoint, load it and set the start step. Otherwise, start from step 0.
+ # --- Checkpoint Resume ---
ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None
if args.checkpoint.resume_from_checkpoint and ckpt_path:
- logger.info(f"Attempting to load checkpoint from {ckpt_path}")
+ logger.info("Attempting to load checkpoint from %s", ckpt_path)
model, optimizer, scheduler, train_dataloader, start_step, epoch = load_checkpoint_fsdp2(
model=model,
optimizer=optimizer,
@@ -144,7 +154,7 @@ def main(args: DictConfig) -> float | None:
dataloader=train_dataloader,
process_group=device_mesh.get_group("dp"),
)
- logger.info(f"Checkpoint loaded, resuming from step {start_step}, epoch {epoch}")
+ logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch)
else:
logger.info("No checkpoint to load, starting from scratch")
start_step = 0
@@ -155,8 +165,8 @@ def main(args: DictConfig) -> float | None:
gc.collect()
torch.cuda.empty_cache()
- # Training loop
- logger.info(f"Starting training loop from step {start_step} to {args.num_train_steps}")
+ # --- Training Loop ---
+ logger.info("Starting training loop from step %s to %s", start_step, args.num_train_steps)
step = start_step
micro_step = 0 # Gradient accumulation step counter
while step < args.num_train_steps:
@@ -176,7 +186,7 @@ def main(args: DictConfig) -> float | None:
# Log microbatch step data for accumulation metrics
perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs)
- # Gradient accumulation - only step optimizer after accumulating gradients
+ # The end of a "full" step (i.e. after possibly multiple gradient accumulation steps).
if micro_step % args.grad_acc_steps == 0:
micro_step = 0
@@ -217,7 +227,7 @@ def main(args: DictConfig) -> float | None:
epoch += 1
dataset_or_sampler.set_epoch(epoch)
- # Save final model to a .safetensors file.
+ # --- Cleanup ---
if args.checkpoint.save_final_model and ckpt_path:
save_final_model_fsdp2(
model=model,
@@ -229,7 +239,6 @@ def main(args: DictConfig) -> float | None:
if args.checkpoint.async_save and "fsdp2" in _ckpt_futures and _ckpt_futures["fsdp2"] is not None:
_ckpt_futures["fsdp2"].result()
- # Clean up distributed training
perf_logger.finish()
torch.distributed.destroy_process_group()
diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
index 742e21a63d..9ad3d0e297 100644
--- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
+++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
@@ -13,6 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+"""FSDP2 with Context Parallelism training script for Llama 3 with TransformerEngine.
+
+Combines Fully Sharded Data Parallel v2 with Context Parallelism (CP), where each sequence is
+split across multiple GPUs along the sequence dimension. This is useful for training with very long
+sequences that do not fit into a single GPU's memory even with FSDP2 alone. Only supports
+TE-accelerated models (NVLlamaForCausalLM).
+
+For standard FSDP2 training without context parallelism, use ``train_fsdp2.py`` instead.
+"""
+
import gc
import logging
from contextlib import nullcontext
@@ -28,7 +38,13 @@
from torch.optim import AdamW
from transformer_engine.common.recipe import Format
-from checkpoint import load_checkpoint_fsdp2, save_checkpoint_fsdp2, save_final_model_fsdp2, should_save_checkpoint
+from checkpoint import (
+ _ckpt_futures,
+ load_checkpoint_fsdp2,
+ save_checkpoint_fsdp2,
+ save_final_model_fsdp2,
+ should_save_checkpoint,
+)
from collator import ContextParallelDataLoaderWrapper, DataCollatorForContextParallel
from dataset import create_bshd_dataloader, create_thd_dataloader
from distributed_config import DistributedConfig
@@ -43,37 +59,36 @@
@hydra.main(config_path="hydra_config", config_name="L0_sanity_cp", version_base="1.2")
def main(args: DictConfig) -> float | None:
- """Train Llama3 with TE layers using FSDP2.
+ """Train Llama3 with TE layers using FSDP2 with Context Parallelism.
Returns:
float: The loss value for the final batch.
"""
- # Initialize the distributed configuration, including creating the distributed process group.
+ # --- Distributed Setup ---
dist_config = DistributedConfig()
logger.info("Initializing distributed training: %s", dist_config)
device = torch.device(f"cuda:{dist_config.local_rank}")
torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl", device_id=device)
torch.cuda.set_device(dist_config.local_rank)
- # Create a device mesh for FSDP.
device_mesh = init_device_mesh(
"cuda",
mesh_shape=(dist_config.world_size // args.cp_size, args.cp_size),
mesh_dim_names=("dp", "cp"),
)
- logger.info(f"Created device mesh: {device_mesh}")
+ logger.info("Created device mesh: %s", device_mesh)
- # Create an FP8 recipe -- this is only used if FP8 is enabled in the config.
+ # --- Model Configuration ---
fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
)
- # Create an empty Llama3 model with a causal language model head, e.g. "meta-llama/Meta-Llama-3-8B".
+ # --- Model Initialization ---
config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
# Optionally use transformer engine to initialize only fp8 versions of weights by setting
- # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 and fp8
- # versions of weights are kept.
+ # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16
+ # and fp8 versions of weights are kept.
with (
torch.device("meta") if args.use_meta_device else nullcontext(),
transformer_engine.pytorch.quantized_model_init(
@@ -84,7 +99,7 @@ def main(args: DictConfig) -> float | None:
logger.info("Initialized Model:\n%s", model)
- # Create a flattened mesh for FSDP2 sharding. This will shard the model across both the DP and CP ranks.
+ # --- Distributed Wrapping (FSDP2 + CP) ---
cp_dp_mesh = device_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_shard_cp")
# Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers.
@@ -105,7 +120,8 @@ def main(args: DictConfig) -> float | None:
# TE layers require special handling to initialize the weights from the meta device.
model.init_empty_weights()
- # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873).
+ # --- Optimizer & Scheduler ---
+ # Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873).
optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore
scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)
@@ -113,11 +129,16 @@ def main(args: DictConfig) -> float | None:
# If we're using torch.compile, we need to do this before loading the checkpoint to ensure key consistency.
model = torch.compile(model)
- # Create the context-aware dataloader. We only create the dataloader on rank 0 and wrap it in a
- # ContextParallelDataLoaderWrapper that will shard and distribute the data across the context parallelism group.
+ # --- Data Loading ---
+ # Create the context-aware dataloader.
if args.dataset.get("pad_sequences_to_be_divisible_by", None) is None:
+            # The dual chunk algorithm gives each CP rank 2 chunks from each sequence, so each sequence's
+            # length must be divisible by cp_mesh.size() * 2.
logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2")
OmegaConf.update(args, "dataset.pad_sequences_to_be_divisible_by", device_mesh["cp"].size() * 2)
+
+ # We only create the dataloader on rank 0, which is responsible for loading data for all CP (and eventually TP)
+ # ranks. This ensures that the data remains synchronized, even if we're using a non-deterministic data pipeline.
if device_mesh["cp"].get_local_rank() == 0:
if args.use_sequence_packing:
train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset)
@@ -135,12 +156,13 @@ def main(args: DictConfig) -> float | None:
train_dataloader = None
dataset_or_sampler = None
+ # On all ranks, we create a ContextParallelDataLoaderWrapper that broadcasts the data from cp rank 0.
train_dataloader = ContextParallelDataLoaderWrapper(train_dataloader, device_mesh["cp"])
- # If we're resuming from a checkpoint, load it and set the start step. Otherwise, start from step 0.
+ # --- Checkpoint Resume ---
ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None
if args.checkpoint.resume_from_checkpoint and ckpt_path:
- logger.info(f"Attempting to load checkpoint from {ckpt_path}")
+ logger.info("Attempting to load checkpoint from %s", ckpt_path)
model, optimizer, scheduler, train_dataloader, start_step, epoch = load_checkpoint_fsdp2(
model=model,
optimizer=optimizer,
@@ -150,7 +172,7 @@ def main(args: DictConfig) -> float | None:
dataloader=train_dataloader,
process_group=cp_dp_mesh.get_group(),
)
- logger.info(f"Checkpoint loaded, resuming from step {start_step}, epoch {epoch}")
+ logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch)
else:
logger.info("No checkpoint to load, starting from scratch")
start_step = 0
@@ -161,8 +183,8 @@ def main(args: DictConfig) -> float | None:
gc.collect()
torch.cuda.empty_cache()
- # Training loop
- logger.info(f"Starting training loop from step {start_step} to {args.num_train_steps}")
+ # --- Training Loop ---
+ logger.info("Starting training loop from step %s to %s", start_step, args.num_train_steps)
step = start_step
micro_step = 0 # Gradient accumulation step counter
while step < args.num_train_steps:
@@ -185,7 +207,7 @@ def main(args: DictConfig) -> float | None:
# Log microbatch step data for accumulation metrics
perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs)
- # Gradient accumulation - only step optimizer after accumulating gradients
+ # The end of a "full" step (i.e. after possibly multiple gradient accumulation steps).
if micro_step % args.grad_acc_steps == 0:
micro_step = 0
@@ -227,7 +249,7 @@ def main(args: DictConfig) -> float | None:
if dataset_or_sampler is not None: # The dataset only exists on rank 0
dataset_or_sampler.set_epoch(epoch)
- # Save final model to a .safetensors file.
+ # --- Cleanup ---
if args.checkpoint.save_final_model and ckpt_path:
save_final_model_fsdp2(
model=model,
@@ -235,7 +257,10 @@ def main(args: DictConfig) -> float | None:
dist_config=dist_config,
)
- # Clean up distributed training
+ # Make sure we don't have any outstanding checkpoint save futures.
+ if args.checkpoint.async_save and "fsdp2" in _ckpt_futures and _ckpt_futures["fsdp2"] is not None:
+ _ckpt_futures["fsdp2"].result()
+
perf_logger.finish()
torch.distributed.destroy_process_group()
diff --git a/bionemo-recipes/recipes/llama3_native_te/train_mfsdp_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_mfsdp_cp.py
new file mode 100644
index 0000000000..742f7cff01
--- /dev/null
+++ b/bionemo-recipes/recipes/llama3_native_te/train_mfsdp_cp.py
@@ -0,0 +1,281 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Megatron-FSDP with Context Parallelism training script for Llama 3 with TransformerEngine.
+
+Combines Megatron-FSDP with Context Parallelism (CP), where each sequence is split across multiple
+GPUs along the sequence dimension. This is useful for training with very long sequences that do not
+fit into a single GPU's memory even with FSDP alone. Only supports TE-accelerated models
+(NVLlamaForCausalLM).
+
+For Megatron-FSDP training without context parallelism, use ``train_mfsdp.py`` instead.
+For FSDP2 with context parallelism, use ``train_fsdp2_cp.py`` instead.
+"""
+
+import gc
+import logging
+from pathlib import Path
+
+import hydra
+import nvtx
+import torch
+import transformer_engine.pytorch
+from megatron_fsdp.fully_shard import fully_shard as mfsdp_fully_shard
+from omegaconf import DictConfig, OmegaConf
+from torch.distributed.device_mesh import init_device_mesh
+from torch.optim import AdamW
+from transformer_engine.common.recipe import Format
+
+from checkpoint import (
+ load_checkpoint_mfsdp,
+ save_checkpoint_mfsdp,
+ save_final_model_mfsdp,
+ should_save_checkpoint,
+)
+from collator import ContextParallelDataLoaderWrapper, DataCollatorForContextParallel
+from dataset import create_bshd_dataloader, create_thd_dataloader
+from distributed_config import DistributedConfig
+from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
+from perf_logger import PerfLogger
+from scheduler import get_cosine_annealing_schedule_with_warmup
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+@hydra.main(config_path="hydra_config", config_name="L0_sanity_cp", version_base="1.2")
+def main(args: DictConfig) -> float | None:
+ """Train Llama3 with TE layers using Megatron-FSDP with Context Parallelism.
+
+ Returns:
+ float: The loss value for the final batch.
+ """
+ # --- Distributed Setup ---
+ dist_config = DistributedConfig()
+ logger.info("Initializing distributed training: %s", dist_config)
+ device = torch.device(f"cuda:{dist_config.local_rank}")
+ torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl", device_id=device)
+ torch.cuda.set_device(dist_config.local_rank)
+
+ # Create a 3D device mesh (dp, cp, tp) where tp is a dummy dimension of size 1 required by mfsdp.
+ dp_size = dist_config.world_size // args.cp_size
+ device_mesh = init_device_mesh(
+ "cuda",
+ mesh_shape=(dp_size, args.cp_size, 1),
+ mesh_dim_names=("dp", "cp", "tp"),
+ )
+ logger.info("Created device mesh: %s", device_mesh)
+
+ # --- Model Configuration ---
+ fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
+ fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
+ )
+
+ # --- Model Initialization ---
+ config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
+
+ # mfsdp does not support tied weight parameters. If tie_word_embeddings is enabled, we need to untie them so that
+ # lm_head.weight and embed_tokens.weight are separate parameters for the mfsdp optimizer buffer.
+ if config.tie_word_embeddings:
+ logger.warning(
+ "Megatron-FSDP does not support tied weight parameters. Setting tie_word_embeddings=False. "
+ "This means lm_head.weight will be a separate parameter from embed_tokens.weight."
+ )
+ config.tie_word_embeddings = False
+
+ # Optionally use transformer engine to initialize only fp8 versions of weights by setting
+ # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16
+ # and fp8 versions of weights are kept.
+ # NOTE: Meta device initialization for mfsdp is handled by the `init_model_with_meta_device` kwarg in
+ # fully_shard_kwargs, so we do NOT use `torch.device("meta")` here (unlike train_fsdp2_cp.py).
+ with transformer_engine.pytorch.quantized_model_init(
+ recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs
+ ):
+ model = NVLlamaForCausalLM(config)
+
+ logger.info("Initialized Model:\n%s", model)
+
+ # --- Optimizer (created before mfsdp wrapping, will be wrapped by fully_shard) ---
+ # Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873).
+ optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore
+
+ # --- Distributed Wrapping (Megatron-FSDP + CP) ---
+ model, optimizer = mfsdp_fully_shard(
+ module=model,
+ optimizer=optimizer,
+ fsdp_unit_modules=[
+ transformer_engine.pytorch.TransformerLayer,
+ transformer_engine.pytorch.LayerNorm,
+ transformer_engine.pytorch.LayerNormLinear,
+ ],
+ device_mesh=device_mesh,
+ dp_shard_dim="dp",
+ tp_dim="tp",
+ **args.fully_shard_kwargs,
+ )
+
+ # Attach the CP group to each transformer layer.
+ for layer in model.module.model.layers:
+ layer.set_context_parallel_group(
+ device_mesh["cp"].get_group(),
+ torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()),
+ torch.cuda.Stream(),
+ )
+
+ # --- Scheduler (must be created after mfsdp wrapping since fully_shard modifies the optimizer) ---
+ scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)
+
+ if args.use_torch_compile:
+ logger.warning(
+ "BIONEMO-2977: Using torch.compile with mfsdp is currently not supported. `use_torch_compile` was set to "
+ "true, but will be ignored."
+ )
+
+ # --- Data Loading ---
+ # Create the context-aware dataloader.
+ if args.dataset.get("pad_sequences_to_be_divisible_by", None) is None:
+ # The dual chunk algorithm gives each CP rank 2 chunks from each sequence, so we need each sequence to be
+ # divisible by cp_mesh.size() * 2.
+ logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2")
+ OmegaConf.update(args, "dataset.pad_sequences_to_be_divisible_by", device_mesh["cp"].size() * 2)
+
+ # We only create the dataloader on rank 0, which is responsible for loading data for all CP (and eventually TP)
+ # ranks. This ensures that the data remains synchronized, even if we're using a non-deterministic data pipeline.
+ if device_mesh["cp"].get_local_rank() == 0:
+ if args.use_sequence_packing:
+ train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset)
+ else:
+ train_dataloader, dataset_or_sampler = create_bshd_dataloader(dist_config, **args.dataset)
+
+ train_dataloader.collate_fn = DataCollatorForContextParallel(
+ collator=train_dataloader.collate_fn,
+ device_mesh=device_mesh,
+ qkv_format=args.config_kwargs.attn_input_format,
+ is_causal_lm=True,
+ )
+
+ else:
+ train_dataloader = None
+ dataset_or_sampler = None
+
+ # On all ranks, we create a ContextParallelDataLoaderWrapper that broadcasts the data from cp rank 0.
+ train_dataloader = ContextParallelDataLoaderWrapper(train_dataloader, device_mesh["cp"])
+
+ # --- Checkpoint Resume ---
+ ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_mfsdp" if args.checkpoint.ckpt_dir else None
+ if args.checkpoint.resume_from_checkpoint and ckpt_path:
+ logger.info("Attempting to load checkpoint from %s", ckpt_path)
+ model, optimizer, scheduler, train_dataloader, start_step, epoch = load_checkpoint_mfsdp(
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ ckpt_path=ckpt_path,
+ dist_config=dist_config,
+ dataloader=train_dataloader,
+ )
+ logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch)
+ else:
+ logger.info("No checkpoint to load, starting from scratch")
+ start_step = 0
+ epoch = 0
+
+ perf_logger = PerfLogger(dist_config, args)
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # --- Training Loop ---
+ logger.info("Starting training loop from step %s to %s", start_step, args.num_train_steps)
+ step = start_step
+ micro_step = 0 # Gradient accumulation step counter
+ while step < args.num_train_steps:
+ for batch in train_dataloader:
+ batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901
+
+ micro_step += 1
+
+ # Forward pass with mixed precision.
+ with nvtx.annotate("Forward pass", color="green"):
+ with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe):
+ outputs = model(**batch)
+
+ # Backward pass - scale loss by grad_acc_steps for proper gradient averaging
+ loss = outputs.loss / args.grad_acc_steps
+
+ with nvtx.annotate("Backward pass", color="red"):
+ loss.backward()
+
+ # Log microbatch step data for accumulation metrics
+ perf_logger.log_micro_step(step=step, batch=batch, outputs=outputs)
+
+ # The end of a "full" step (i.e. after possibly multiple gradient accumulation steps).
+ if micro_step % args.grad_acc_steps == 0:
+ micro_step = 0
+
+ # Compute and clip gradient norms.
+ # NOTE: grad clipping with mfsdp has been reported to cause hangs in some configurations.
+ # If you experience hangs, try commenting out the clip_grad_norm_ call.
+ total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+
+ # Step optimizer.
+ optimizer.step()
+ scheduler.step()
+ optimizer.zero_grad()
+
+ perf_logger.log_step(
+ step=step,
+ grad_norm=total_norm,
+ lr=optimizer.param_groups[0]["lr"],
+ )
+
+ if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
+ save_checkpoint_mfsdp(
+ model=model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ ckpt_path=ckpt_path,
+ step=step,
+ epoch=epoch,
+ dist_config=dist_config,
+ dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
+ max_checkpoints=args.checkpoint.max_checkpoints,
+ )
+
+ step += 1
+ if step >= args.num_train_steps:
+ break
+
+ # Dataloader exhausted, incrementing epoch
+ epoch += 1
+ if dataset_or_sampler is not None: # The dataset only exists on rank 0
+ dataset_or_sampler.set_epoch(epoch)
+
+ # --- Cleanup ---
+ if args.checkpoint.save_final_model and ckpt_path:
+ save_final_model_mfsdp(
+ model=model,
+ save_directory=ckpt_path / "final_model",
+ dist_config=dist_config,
+ )
+
+ perf_logger.finish()
+ torch.distributed.destroy_process_group()
+
+ return perf_logger.min_loss
+
+
+if __name__ == "__main__":
+ main()