diff --git a/bionemo-recipes/recipes/llama3_native_te/checkpoint.py b/bionemo-recipes/recipes/llama3_native_te/checkpoint.py index bd6339035..0573ca9d1 100644 --- a/bionemo-recipes/recipes/llama3_native_te/checkpoint.py +++ b/bionemo-recipes/recipes/llama3_native_te/checkpoint.py @@ -34,6 +34,7 @@ from torch.distributed.checkpoint.state_dict_saver import async_save as dcp_async_save from torch.distributed.checkpoint.state_dict_saver import save as dcp_save from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.tensor import DTensor from torchdata.stateful_dataloader import StatefulDataLoader from distributed_config import DistributedConfig @@ -219,8 +220,28 @@ class AppState(Stateful): epoch: int = 0 def state_dict(self): - """Get the state dict for the model, optimizer, scheduler, and step.""" + """ + Get the state dict for the model, optimizer, scheduler, and step. + This factory both retrieves the model state dictionary when saving + checkpoints and initializes a destination for the state read from + DCP checkpoint files when loading checkpoints. + """ model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer) + for fqn in list(model_state_dict.keys()): + # Get the model parameter. + model_param = model_state_dict[fqn] + if isinstance(model_param, DTensor): + model_param = model_param.to_local() + if model_param.numel() == 0 and fqn in optimizer_state_dict['state']: + # Empty model parameter. Clear the associated optimizer state + # when initializing the optimizer state upon DCP load, because + # empty optimizer state DTensors are not checkpointed with DCP, + # yet get_state_dict / _init_optim_state produce empty Tensors. + # TransformerEngine uses empty Tensors for dummy Parameters. + optimizer_state_dict['state'][fqn] = {} + if fqn.endswith("._extra_state"): + # Evict `_extra_state` quantization data from model checkpoint. 
+ model_state_dict.pop(fqn) return { "model": model_state_dict, "optim": optimizer_state_dict, @@ -230,12 +251,19 @@ def state_dict(self): } def load_state_dict(self, state_dict: dict): - """Load the state dict for the model, optimizer, scheduler, and step.""" + """ + Load the state dict for the model, optimizer, scheduler, and step. + Given the checkpoint-loaded state_dict, set the state of the model, + optimizer, scheduler, step, and epoch to the values in state_dict. + """ set_state_dict( self.model, self.optimizer, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"], + # Non-strict checkpoint loading ignores empty optimizer states, + # skips loading non-FP8 checkpoint weights (e.g. _extra_state). + options=StateDictOptions(strict=False), ) self.scheduler.load_state_dict(state_dict["scheduler"]) self.step = state_dict["step"] diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_sanity_nd.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_sanity_nd.yaml new file mode 100644 index 000000000..f6eb562b6 --- /dev/null +++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_sanity_nd.yaml @@ -0,0 +1,31 @@ +defaults: + - L0_sanity + - _self_ + +tp_size: 2 +cp_size: 2 + +dataset: + # CP2 * (8 for FP8 Activations, 16 for FP8 Parameters) + pad_sequences_to_be_divisible_by: 32 + +fp8_config: + enabled: true + fp8_recipe: transformer_engine.common.recipe.DelayedScaling + fp8_format: "HYBRID" + fp8_recipe_kwargs: {} + quantized_model_init_kwargs: + # TODO(@cspades): Quantized parameters are + # NOT supported with DCP checkpointing. + enabled: true + +checkpoint: + ckpt_dir: ./fsdp_tp_ckpts + save_final_model: true + +config_kwargs: + attn_input_format: "bshd" # Alternatively "thd" on datacenter hardware. + self_attn_mask_type: "causal" # Alternatively "padding_causal" for THD inputs. + tensor_parallel: true # Tensor Parallelism for TE + sequence_parallel: true # Sequence parallelism for LayerNorm on TP ranks. 
+ tp_size: ${tp_size} # Tensor Parallel Size \ No newline at end of file diff --git a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py index 274320f86..6df9f3019 100644 --- a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py +++ b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py @@ -22,6 +22,8 @@ import torch.nn as nn import transformer_engine.pytorch import transformers +from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module +from torch.distributed.tensor.placement_types import Replicate from transformer_engine.pytorch.attention import InferenceParams from transformer_engine.pytorch.attention.inference import PagedKVCacheManager from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding @@ -49,6 +51,18 @@ class NVLlamaConfig(LlamaConfig): # "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format) attn_input_format: str = "thd" self_attn_mask_type: str = "padding_causal" + tensor_parallel: bool = False + sequence_parallel: bool = False + tp_size: int = 1 + tp_mesh: torch.distributed.DeviceMesh | None = None + weight_mesh: torch.distributed.DeviceMesh | None = None + + def to_dict(self): + config_dict = super().to_dict() + # DeviceMesh is not serializable. Don't checkpoint it. + config_dict.pop("tp_mesh", None) + config_dict.pop("weight_mesh", None) + return config_dict class NVLlamaPreTrainedModel(PreTrainedModel): @@ -114,9 +128,35 @@ def __init__(self, config: LlamaConfig): self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size + self.tp_mesh = config.tp_mesh self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype) + # Tensor-parallelize torch.nn.Embedding. Combines DTensor-based TP with TE-based TP. 
+    if config.tensor_parallel: +        assert self.tp_mesh is not None, ( +            "[NVLlamaModel] Tensor parallelism " +            "requires a NVLlamaConfig.tp_mesh." +        ) +        assert self.tp_mesh.size() == config.tp_size, ( +            f"[NVLlamaModel] DeviceMesh TP size " +            f"({self.tp_mesh.size()}) " +            f"does not match configured TP size ({config.tp_size})." +        ) +        # NOTE(@cspades): Because the TELinear head is weight-tied to torch.nn.Embedding +        # during HuggingFace post-init, this will automatically convert the TELinear head +        # weight into a DTensor with the correct sharding placements prior to FSDP2 +        # fully_shard(), and no need to call TELinear.set_device_mesh(). +        parallelize_module( +            self.embed_tokens, +            self.tp_mesh, +            # Un-sharded output activations for compatible input to TETransformer. +            # NOTE(@cspades): ColwiseParallel -> torch.nn.Embedding -> Shard(dim=1) +            # RowwiseParallel doesn't support output_layouts=Replicate() with +            # torch.compile: https://github.com/pytorch/torchtitan/issues/534 +            ColwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()) +        ) + def _init_method(x): torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) @@ -142,6 +182,11 @@ def _init_method(x): device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", init_method=_init_method, output_layer_init_method=_init_method, +            set_parallel_mode=config.tensor_parallel, +            sequence_parallel=config.sequence_parallel, +            tp_size=config.tp_size, +            tp_mesh=config.tp_mesh, +            weight_mesh=config.weight_mesh, ) for layer_idx in range(config.num_hidden_layers) ] @@ -152,6 +197,8 @@ def _init_method(x): dtype=config.dtype, device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", ) +        # Norm modules are non-Base TransformerEngine modules that require a manual call for TP. 
+ self.norm.set_device_mesh(tp_mesh=config.tp_mesh, weight_mesh=config.weight_mesh) # We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original # LlamaRotaryEmbedding. @@ -283,6 +330,7 @@ def __init__(self, config): super().__init__(config) self.model = NVLlamaModel(config) self.vocab_size = config.vocab_size + self.tp_mesh = config.tp_mesh with transformer_engine.pytorch.quantized_model_init(enabled=False): self.lm_head = transformer_engine.pytorch.Linear( config.hidden_size, @@ -291,9 +339,19 @@ def __init__(self, config): params_dtype=config.dtype, device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + parallel_mode="row" if config.tensor_parallel else None, + # This scatters your output, not ever needed for final layer. + # Will all-reduce the output instead, as required. + sequence_parallel=False, + tp_size=config.tp_size, ) + if self.config.tensor_parallel: + # If using tensor parallelism, the head weights have already been tied + # to the embedding weights. Just set the tensor parallel group for TE. + # No parameter quantization either, so no need for weight_mesh. + self.lm_head.set_tensor_parallel_group(self.tp_mesh.get_group()) - # Initialize weights and apply final processing + # Initialize weights and apply final processing. Ties weights. self.post_init() def forward( @@ -346,6 +404,13 @@ def forward( # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + if self.config.tensor_parallel: + # If using TP, shard your activation across the TP group, + # to support row-wise tensor parallelism in the LM head. 
+            tp_rank = self.tp_mesh.get_local_rank() +            tp_stride = hidden_states.shape[-1] // self.config.tp_size +            hidden_states = hidden_states[..., tp_rank*tp_stride:(tp_rank + 1)*tp_stride] + with transformer_engine.pytorch.autocast(enabled=False): if hidden_states.ndim == 3: logits = self.lm_head(hidden_states[:, slice_indices, :]) diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_nd_parallel.py similarity index 80% rename from bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py rename to bionemo-recipes/recipes/llama3_native_te/train_fsdp2_nd_parallel.py index 9ad3d0e29..49a14b5c8 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_nd_parallel.py @@ -57,7 +57,7 @@ logger.setLevel(logging.INFO) -@hydra.main(config_path="hydra_config", config_name="L0_sanity_cp", version_base="1.2") +@hydra.main(config_path="hydra_config", config_name="L2_sanity_nd", version_base="1.2") def main(args: DictConfig) -> float | None: """Train Llama3 with TE layers using FSDP2 with Context Parallelism. @@ -73,8 +73,8 @@ def main(args: DictConfig) -> float | None: device_mesh = init_device_mesh( "cuda", - mesh_shape=(dist_config.world_size // args.cp_size, args.cp_size), - mesh_dim_names=("dp", "cp"), + mesh_shape=(dist_config.world_size // (args.cp_size * args.tp_size), args.cp_size, args.tp_size), + mesh_dim_names=("dp", "cp", "tp"), ) logger.info("Created device mesh: %s", device_mesh) @@ -85,6 +85,22 @@ def main(args: DictConfig) -> float | None: # --- Model Initialization --- config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs) + + # Identify DeviceMesh that are propagated to `set_device_mesh` in TransformerEngine modules. + # These will convert TransformerEngine parameters into DTensors. 
Alternatively, users can + # manually call the conversion using `TransformerEngineModule.set_device_mesh(...)`` before + # `reset_parameters` (which triggers quantization) if the module supports DTensor parameters. + if config.tensor_parallel: + config.tp_mesh = device_mesh["tp"] + if ( + args.fp8_config.quantized_model_init_kwargs.enabled + and isinstance(fp8_recipe, transformer_engine.common.recipe.Float8CurrentScaling) + ): + # When using per-tensor FP8 recipes for quantized parameters, TransformerEngine + # requires a weight sharding mesh for absmax reduction across distributed weights. + # If not provided, will default to DTensor.device_mesh.get_group(), which is not + # appropriate if HSDP (DP-Replicate x DP-Shard) is used. + config.weight_mesh = device_mesh["dp", "cp", "tp"]._flatten("weight_mesh") # Optionally use transformer engine to initialize only fp8 versions of weights by setting # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16 @@ -100,6 +116,8 @@ def main(args: DictConfig) -> float | None: logger.info("Initialized Model:\n%s", model) # --- Distributed Wrapping (FSDP2 + CP) --- + + # Create a flattened mesh for FSDP2-CP sharding. This will shard the model across both the DP and CP ranks. cp_dp_mesh = device_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_shard_cp") # Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers. @@ -108,7 +126,7 @@ def main(args: DictConfig) -> float | None: fully_shard(layer, mesh=cp_dp_mesh) fully_shard(model, mesh=cp_dp_mesh) - # Attach the CP group to the model. + # Attach the CP ProcessGroup to the TransformerEngine model. 
for layer in model.model.layers: layer.set_context_parallel_group( device_mesh["cp"].get_group(), @@ -137,9 +155,12 @@ def main(args: DictConfig) -> float | None: logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2") OmegaConf.update(args, "dataset.pad_sequences_to_be_divisible_by", device_mesh["cp"].size() * 2) - # We only create the dataloader on rank 0, which is responsible for loading data for all CP (and eventually TP) - # ranks. This ensures that the data remains synchronized, even if we're using a non-deterministic data pipeline. - if device_mesh["cp"].get_local_rank() == 0: + # We only create the dataloader on rank 0, which is responsible for loading data for all CP (and TP) ranks. + # This ensures that the data remains synchronized, even if we're using a non-deterministic data pipeline. + cp_tp_mesh = device_mesh["cp", "tp"]._flatten(mesh_dim_name="cp_tp") + if cp_tp_mesh.get_local_rank() == 0: + # We only create the dataloader on CP-TP Rank 0 and pass it to a ContextParallelDataLoaderWrapper + # that will shard, replicate, and distribute the data across the flattened CP and TP group. if args.use_sequence_packing: train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset) else: @@ -156,8 +177,8 @@ def main(args: DictConfig) -> float | None: train_dataloader = None dataset_or_sampler = None - # On all ranks, we create a ContextParallelDataLoaderWrapper that broadcasts the data from cp rank 0. - train_dataloader = ContextParallelDataLoaderWrapper(train_dataloader, device_mesh["cp"]) + # Deliver CP-sharded replicates to a flattened CP-TP mesh. 
+ train_dataloader = ContextParallelDataLoaderWrapper(train_dataloader, cp_tp_mesh) # --- Checkpoint Resume --- ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None @@ -170,7 +191,6 @@ def main(args: DictConfig) -> float | None: ckpt_path=ckpt_path, dist_config=dist_config, dataloader=train_dataloader, - process_group=cp_dp_mesh.get_group(), ) logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch) else: @@ -226,6 +246,13 @@ def main(args: DictConfig) -> float | None: ) if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps): + if args.checkpoint.async_save and args.fp8_config.quantized_model_init_kwargs.enabled: + logger.info( + "Asynchronous checkpointing is not supported with TransformerEngine " + "quantized parameters and FSDP2. Using synchronous checkpointing " + "(checkpoint.async_save=false)..." + ) + OmegaConf.update(args, "checkpoint.async_save", False) save_checkpoint_fsdp2( model=model, optimizer=optimizer, @@ -235,7 +262,6 @@ def main(args: DictConfig) -> float | None: epoch=epoch, dist_config=dist_config, dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None, - process_group=cp_dp_mesh.get_group(), max_checkpoints=args.checkpoint.max_checkpoints, async_save=args.checkpoint.async_save, ) @@ -268,4 +294,4 @@ def main(args: DictConfig) -> float | None: if __name__ == "__main__": - main() + main() \ No newline at end of file