Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions bionemo-recipes/recipes/llama3_native_te/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from torch.distributed.checkpoint.state_dict_saver import async_save as dcp_async_save
from torch.distributed.checkpoint.state_dict_saver import save as dcp_save
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.tensor import DTensor
from torchdata.stateful_dataloader import StatefulDataLoader

from distributed_config import DistributedConfig
Expand Down Expand Up @@ -219,8 +220,28 @@ class AppState(Stateful):
epoch: int = 0

def state_dict(self):
"""Get the state dict for the model, optimizer, scheduler, and step."""
"""
Get the state dict for the model, optimizer, scheduler, and step.
This factory both retrieves the model state dictionary when saving
checkpoints and initializes a destination for the state read from
DCP checkpoint files when loading checkpoints.
"""
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
for fqn in list(model_state_dict.keys()):
# Get the model parameter.
model_param = model_state_dict[fqn]
if isinstance(model_param, DTensor):
model_param = model_param.to_local()
if model_param.numel() == 0 and fqn in optimizer_state_dict['state']:
# Empty model parameter. Clear the associated optimizer state
# when initializing the optimizer state upon DCP load, because
# empty optimizer state DTensors are not checkpointed with DCP,
# yet get_state_dict / _init_optim_state produce empty Tensors.
# TransformerEngine uses empty Tensors for dummy Parameters.
optimizer_state_dict['state'][fqn] = {}
if fqn.endswith("._extra_state"):
# Evict `_extra_state` quantization data from model checkpoint.
model_state_dict.pop(fqn)
return {
"model": model_state_dict,
"optim": optimizer_state_dict,
Expand All @@ -230,12 +251,19 @@ def state_dict(self):
}

def load_state_dict(self, state_dict: dict):
"""Load the state dict for the model, optimizer, scheduler, and step."""
"""
Load the state dict for the model, optimizer, scheduler, and step.
Given the checkpoint-loaded state_dict, set the state of the model,
optimizer, scheduler, step, and epoch to the values in state_dict.
"""
set_state_dict(
self.model,
self.optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"],
# Non-strict checkpoint loading ignores empty optimizer states,
# skips loading non-FP8 checkpoint weights (e.g. _extra_state).
options=StateDictOptions(strict=False),
)
self.scheduler.load_state_dict(state_dict["scheduler"])
self.step = state_dict["step"]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
defaults:
- L0_sanity
- _self_

tp_size: 2
cp_size: 2

dataset:
# CP2 * (8 for FP8 Activations, 16 for FP8 Parameters)
pad_sequences_to_be_divisible_by: 32

fp8_config:
enabled: true
fp8_recipe: transformer_engine.common.recipe.DelayedScaling
fp8_format: "HYBRID"
fp8_recipe_kwargs: {}
quantized_model_init_kwargs:
# TODO(@cspades): Quantized parameters are
# NOT supported with DCP checkpointing.
enabled: true

checkpoint:
ckpt_dir: ./fsdp_tp_ckpts
save_final_model: true

config_kwargs:
attn_input_format: "bshd" # Alternatively "thd" on datacenter hardware.
self_attn_mask_type: "causal" # Alternatively "padding_causal" for THD inputs.
tensor_parallel: true # Tensor Parallelism for TE
sequence_parallel: true # Sequence parallelism for LayerNorm on TP ranks.
tp_size: ${tp_size} # Tensor Parallel Size
67 changes: 66 additions & 1 deletion bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import torch.nn as nn
import transformer_engine.pytorch
import transformers
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
from torch.distributed.tensor.placement_types import Replicate
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention.inference import PagedKVCacheManager
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
Expand Down Expand Up @@ -49,6 +51,18 @@ class NVLlamaConfig(LlamaConfig):
# "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
attn_input_format: str = "thd"
self_attn_mask_type: str = "padding_causal"
tensor_parallel: bool = False
sequence_parallel: bool = False
tp_size: int = 1
tp_mesh: torch.distributed.DeviceMesh | None = None
weight_mesh: torch.distributed.DeviceMesh | None = None

def to_dict(self):
    """Serialize the config to a plain dict, dropping non-serializable fields.

    ``tp_mesh`` and ``weight_mesh`` hold ``torch.distributed.DeviceMesh``
    objects, which cannot be pickled or JSON-serialized; they are removed
    so checkpointing and config dumps do not fail.
    """
    serializable = super().to_dict()
    for mesh_field in ("tp_mesh", "weight_mesh"):
        # pop with a default so the method is safe even if the field is absent.
        serializable.pop(mesh_field, None)
    return serializable


class NVLlamaPreTrainedModel(PreTrainedModel):
Expand Down Expand Up @@ -114,9 +128,35 @@ def __init__(self, config: LlamaConfig):
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.tp_mesh = config.tp_mesh

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype)

# Tensor-parallelize torch.nn.Embedding. Combines DTensor-based TP with TE-based TP.
if config.tensor_parallel:
    # BUG FIX: the previous form `assert (condition, "message")` asserts a
    # non-empty tuple, which is always truthy, so neither check could ever
    # fire (CPython emits "SyntaxWarning: assertion is always true").
    # The correct form is `assert condition, "message"`.
    assert self.tp_mesh is not None, (
        "[NVLlamaModel] Tensor parallelism requires a NVLlamaConfig.tp_mesh."
    )
    assert self.tp_mesh.size() == config.tp_size, (
        f"[NVLlamaModel] DeviceMesh TP size ({self.tp_mesh.size()}) "
        f"does not match configured TP size ({config.tp_size})."
    )
    # NOTE(@cspades): Because the TELinear head is weight-tied to torch.nn.Embedding
    # during HuggingFace post-init, this will automatically convert the TELinear head
    # weight into a DTensor with the correct sharding placements prior to FSDP2
    # fully_shard(), and no need to call TELinear.set_device_mesh().
    parallelize_module(
        self.embed_tokens,
        self.tp_mesh,
        # Un-sharded output activations for compatible input to TETransformer.
        # NOTE(@cspades): ColwiseParallel -> torch.nn.Embedding -> Shard(dim=1)
        # RowwiseParallel doesn't support output_layouts=Replicate() with
        # torch.compile: https://github.com/pytorch/torchtitan/issues/534
        ColwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()),
    )

def _init_method(x):
torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range)

Expand All @@ -142,6 +182,11 @@ def _init_method(x):
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
init_method=_init_method,
output_layer_init_method=_init_method,
set_parallel_mode=config.tensor_parallel,
sequence_parallel=config.sequence_parallel,
tp_size=config.tp_size,
tp_mesh=config.tp_mesh,
weight_mesh=config.weight_mesh,
)
for layer_idx in range(config.num_hidden_layers)
]
Expand All @@ -152,6 +197,8 @@ def _init_method(x):
dtype=config.dtype,
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
)
# Norm modules are non-Base TransformerEngine modules that require a manual call for TP.
self.norm.set_device_mesh(tp_mesh=config.tp_mesh, weight_mesh=config.weight_mesh)

# We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original
# LlamaRotaryEmbedding.
Expand Down Expand Up @@ -283,6 +330,7 @@ def __init__(self, config):
super().__init__(config)
self.model = NVLlamaModel(config)
self.vocab_size = config.vocab_size
self.tp_mesh = config.tp_mesh
with transformer_engine.pytorch.quantized_model_init(enabled=False):
self.lm_head = transformer_engine.pytorch.Linear(
config.hidden_size,
Expand All @@ -291,9 +339,19 @@ def __init__(self, config):
params_dtype=config.dtype,
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
parallel_mode="row" if config.tensor_parallel else None,
# This scatters your output, not ever needed for final layer.
# Will all-reduce the output instead, as required.
sequence_parallel=False,
tp_size=config.tp_size,
)
if self.config.tensor_parallel:
# If using tensor parallelism, the head weights have already been tied
# to the embedding weights. Just set the tensor parallel group for TE.
# No parameter quantization either, so no need for weight_mesh.
self.lm_head.set_tensor_parallel_group(self.tp_mesh.get_group())

# Initialize weights and apply final processing
# Initialize weights and apply final processing. Ties weights.
self.post_init()

def forward(
Expand Down Expand Up @@ -346,6 +404,13 @@ def forward(
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep

if self.config.tensor_parallel:
    # If using TP, shard the activation along the hidden (feature) dim so the
    # row-parallel LM head receives its local input partition.
    tp_rank = self.tp_mesh.get_local_rank()
    tp_stride = hidden_states.shape[-1] // self.config.tp_size
    # Use `...` so the slice works for any rank of hidden_states: the code
    # just below branches on `hidden_states.ndim == 3` (BSHD) vs. other
    # layouts (e.g. 2-D THD), but the previous `[:, :, a:b]` slice assumed
    # a 3-D tensor and would mis-slice (or fail on) 2-D activations.
    hidden_states = hidden_states[..., tp_rank * tp_stride : (tp_rank + 1) * tp_stride]

with transformer_engine.pytorch.autocast(enabled=False):
if hidden_states.ndim == 3:
logits = self.lm_head(hidden_states[:, slice_indices, :])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
logger.setLevel(logging.INFO)


@hydra.main(config_path="hydra_config", config_name="L0_sanity_cp", version_base="1.2")
@hydra.main(config_path="hydra_config", config_name="L2_sanity_nd", version_base="1.2")
def main(args: DictConfig) -> float | None:
"""Train Llama3 with TE layers using FSDP2 with Context Parallelism.

Expand All @@ -73,8 +73,8 @@ def main(args: DictConfig) -> float | None:

device_mesh = init_device_mesh(
"cuda",
mesh_shape=(dist_config.world_size // args.cp_size, args.cp_size),
mesh_dim_names=("dp", "cp"),
mesh_shape=(dist_config.world_size // (args.cp_size * args.tp_size), args.cp_size, args.tp_size),
mesh_dim_names=("dp", "cp", "tp"),
)
logger.info("Created device mesh: %s", device_mesh)

Expand All @@ -85,6 +85,22 @@ def main(args: DictConfig) -> float | None:

# --- Model Initialization ---
config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)

# Identify DeviceMesh that are propagated to `set_device_mesh` in TransformerEngine modules.
# These will convert TransformerEngine parameters into DTensors. Alternatively, users can
# manually call the conversion using `TransformerEngineModule.set_device_mesh(...)`` before
# `reset_parameters` (which triggers quantization) if the module supports DTensor parameters.
if config.tensor_parallel:
config.tp_mesh = device_mesh["tp"]
if (
args.fp8_config.quantized_model_init_kwargs.enabled
and isinstance(fp8_recipe, transformer_engine.common.recipe.Float8CurrentScaling)
):
# When using per-tensor FP8 recipes for quantized parameters, TransformerEngine
# requires a weight sharding mesh for absmax reduction across distributed weights.
# If not provided, will default to DTensor.device_mesh.get_group(), which is not
# appropriate if HSDP (DP-Replicate x DP-Shard) is used.
config.weight_mesh = device_mesh["dp", "cp", "tp"]._flatten("weight_mesh")

# Optionally use transformer engine to initialize only fp8 versions of weights by setting
# `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16
Expand All @@ -100,6 +116,8 @@ def main(args: DictConfig) -> float | None:
logger.info("Initialized Model:\n%s", model)

# --- Distributed Wrapping (FSDP2 + CP) ---

# Create a flattened mesh for FSDP2-CP sharding. This will shard the model across both the DP and CP ranks.
cp_dp_mesh = device_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_shard_cp")

# Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers.
Expand All @@ -108,7 +126,7 @@ def main(args: DictConfig) -> float | None:
fully_shard(layer, mesh=cp_dp_mesh)
fully_shard(model, mesh=cp_dp_mesh)

# Attach the CP group to the model.
# Attach the CP ProcessGroup to the TransformerEngine model.
for layer in model.model.layers:
layer.set_context_parallel_group(
device_mesh["cp"].get_group(),
Expand Down Expand Up @@ -137,9 +155,12 @@ def main(args: DictConfig) -> float | None:
logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2")
OmegaConf.update(args, "dataset.pad_sequences_to_be_divisible_by", device_mesh["cp"].size() * 2)

# We only create the dataloader on rank 0, which is responsible for loading data for all CP (and eventually TP)
# ranks. This ensures that the data remains synchronized, even if we're using a non-deterministic data pipeline.
if device_mesh["cp"].get_local_rank() == 0:
# We only create the dataloader on rank 0, which is responsible for loading data for all CP (and TP) ranks.
# This ensures that the data remains synchronized, even if we're using a non-deterministic data pipeline.
cp_tp_mesh = device_mesh["cp", "tp"]._flatten(mesh_dim_name="cp_tp")
if cp_tp_mesh.get_local_rank() == 0:
# We only create the dataloader on CP-TP Rank 0 and pass it to a ContextParallelDataLoaderWrapper
# that will shard, replicate, and distribute the data across the flattened CP and TP group.
if args.use_sequence_packing:
train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset)
else:
Expand All @@ -156,8 +177,8 @@ def main(args: DictConfig) -> float | None:
train_dataloader = None
dataset_or_sampler = None

# On all ranks, we create a ContextParallelDataLoaderWrapper that broadcasts the data from cp rank 0.
train_dataloader = ContextParallelDataLoaderWrapper(train_dataloader, device_mesh["cp"])
# Deliver CP-sharded replicates to a flattened CP-TP mesh.
train_dataloader = ContextParallelDataLoaderWrapper(train_dataloader, cp_tp_mesh)

# --- Checkpoint Resume ---
ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None
Expand All @@ -170,7 +191,6 @@ def main(args: DictConfig) -> float | None:
ckpt_path=ckpt_path,
dist_config=dist_config,
dataloader=train_dataloader,
process_group=cp_dp_mesh.get_group(),
)
logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch)
else:
Expand Down Expand Up @@ -226,6 +246,13 @@ def main(args: DictConfig) -> float | None:
)

if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
if args.checkpoint.async_save and args.fp8_config.quantized_model_init_kwargs.enabled:
logger.info(
"Asynchronous checkpointing is not supported with TransformerEngine "
"quantized parameters and FSDP2. Using synchronous checkpointing "
"(checkpoint.async_save=false)..."
)
OmegaConf.update(args, "checkpoint.async_save", False)
save_checkpoint_fsdp2(
model=model,
optimizer=optimizer,
Expand All @@ -235,7 +262,6 @@ def main(args: DictConfig) -> float | None:
epoch=epoch,
dist_config=dist_config,
dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
process_group=cp_dp_mesh.get_group(),
max_checkpoints=args.checkpoint.max_checkpoints,
async_save=args.checkpoint.async_save,
)
Expand Down Expand Up @@ -268,4 +294,4 @@ def main(args: DictConfig) -> float | None:


if __name__ == "__main__":
main()
main()