# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""FSDP2 distributed training with quantized model initialization.

Extends the single-GPU ``main.py`` example to multi-GPU training using
PyTorch-native FSDP2 (``fully_shard``). The script demonstrates:

1. **Meta-device initialization** -- Model parameters are created on the
   ``meta`` device (zero memory), then FSDP2 sharding is applied, and
   finally ``reset_parameters()`` materializes and quantizes only the
   local shards on each rank's GPU.
2. ``quantized_model_init`` -- Flags the model for FP8 weight initialization
   (actual quantization happens in ``reset_parameters`` after sharding).
3. ``fully_shard`` -- PyTorch FSDP2 sharding of each TransformerLayer.
4. ``FusedAdam`` with FP32 master weights for full-precision training updates.

.. note::
    ``fuse_wgrad_accumulation`` is **not** used here. That feature writes
    weight gradients directly into ``main_grad`` buffers, bypassing the
    autograd gradient flow. FSDP2 requires gradients to go through its
    reduce-scatter, so ``fuse_wgrad_accumulation`` needs Megatron-Core's
    FSDP integration (which provides ``get_main_grad()``).

Usage::

    torchrun --nproc-per-node 2 fully_shard.py
"""

import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor

import transformer_engine.pytorch as te
from transformer_engine.pytorch import QuantizedTensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule

# ── Configuration (matches main.py) ──────────────────────────────────
HIDDEN_SIZE = 256
FFN_HIDDEN_SIZE = 1024
NUM_ATTENTION_HEADS = 8
NUM_LAYERS = 3
SEQ_LEN = 32
BATCH_PER_RANK = 2
NUM_STEPS = 5
DTYPE = torch.bfloat16


def dist_print(msg):
    """Print only on rank 0."""
    if int(os.environ.get("RANK", "0")) == 0:
        print(msg)


def main():
    # ── 1. Distributed setup ─────────────────────────────────────────
    assert "TORCHELASTIC_RUN_ID" in os.environ, (
        "This script must be launched with torchrun, e.g.:\n"
        "    torchrun --nproc-per-node 2 fully_shard.py"
    )
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{local_rank}")

    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # ── 2. Create model on meta device (zero memory) ────────────────
    # quantized_model_init sets the flag for FP8 weight initialization,
    # but with device="meta" no actual memory is allocated yet.
    with te.quantized_model_init(enabled=True):
        model = torch.nn.Sequential(
            *[
                te.TransformerLayer(
                    HIDDEN_SIZE,
                    FFN_HIDDEN_SIZE,
                    NUM_ATTENTION_HEADS,
                    fuse_qkv_params=True,
                    params_dtype=DTYPE,
                    hidden_dropout=0.0,
                    attention_dropout=0.0,
                    device="meta",
                )
                for _ in range(NUM_LAYERS)
            ]
        )

    # Verify all parameters are on meta device (no GPU memory used).
    for name, param in model.named_parameters():
        assert param.device == torch.device("meta"), f"{name} is not on meta device"
    dist_print("Model created on meta device (zero GPU memory).")
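
    # Illustrative sanity check (an addition to the original example): with
    # every parameter on the meta device, the CUDA caching allocator should
    # report essentially no memory in use for this model yet.
    dist_print(
        f"  CUDA memory currently allocated: {torch.cuda.memory_allocated(device)} bytes"
    )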

    # ── 3. FSDP2 sharding ────────────────────────────────────────────
    # Apply sharding to the meta-device model. FSDP2 wraps parameters
    # as DTensors but no GPU memory is allocated yet.
    mesh = DeviceMesh("cuda", list(range(world_size)))
    for child in model.children():
        fully_shard(child, mesh=mesh)
    fully_shard(model, mesh=mesh)
    dist_print("FSDP2 sharding applied to meta-device model.")
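
    # Illustrative check (added here): fully_shard has already swapped the
    # parameters for DTensors, but their local shards are still meta
    # tensors, so no CUDA memory is used until reset_parameters() runs below.
    for name, param in model.named_parameters():
        assert isinstance(param, DTensor), f"{name} is not a DTensor after fully_shard"
        assert param._local_tensor.is_meta, f"{name} local shard materialized early"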

    # ── 4. Materialize parameters on GPU ──────────────────────────────
    # reset_parameters() on each TE module materializes the local shard
    # on CUDA, applies weight initialization, and quantizes to FP8.
    for module in model.modules():
        if isinstance(module, TransformerEngineBaseModule):
            module.reset_parameters()

    # Post-materialization verification.
    for name, param in model.named_parameters():
        assert isinstance(param, DTensor), f"{name} is not a DTensor after sharding"
    qt_count = sum(
        1
        for _, p in model.named_parameters()
        if isinstance(p, DTensor) and isinstance(p._local_tensor, QuantizedTensor)
    )
    assert qt_count > 0, "No QuantizedTensor local tensors after materialization"
    dist_print(
        f"Parameters materialized: {qt_count} FP8 (QuantizedTensor) weight params "
        "wrapped in DTensors."
    )
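
    # Illustration (added): the local shards now occupy real CUDA memory.
    dist_print(
        "  CUDA memory after materialization: "
        f"{torch.cuda.memory_allocated(device) / 1024**2:.2f} MiB"
    )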

    # ── 5. Optimizer ─────────────────────────────────────────────────
    optimizer = te.optimizers.FusedAdam(
        model.parameters(),
        lr=1e-3,
        master_weights=True,
        master_weight_dtype=torch.float32,
    )
    dist_print("Using FusedAdam with master_weights=True.")

    # ── 6. Training loop ─────────────────────────────────────────────
    x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device)
    target = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device)

    for step in range(NUM_STEPS):
        optimizer.zero_grad(set_to_none=True)

        with te.autocast(enabled=True):
            output = model(x)

        loss = F.mse_loss(output, target)
        loss.backward()
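
        # Illustrative check (added): FSDP2 reduce-scatters gradients during
        # backward, so any populated .grad is a DTensor sharded like its
        # parameter.
        for param in model.parameters():
            assert param.grad is None or isinstance(param.grad, DTensor)
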
        optimizer.step()
        dist_print(f"  Step {step}: loss = {loss.item():.6f}")

    # ── 7. Post-training assertions ──────────────────────────────────
    dist_print("\nVerifying invariants ...")

    qt_after = 0
    for name, param in model.named_parameters():
        assert isinstance(param, DTensor), f"{name} lost DTensor wrapping"
        if isinstance(param._local_tensor, QuantizedTensor):
            qt_after += 1
    assert qt_after > 0, "No QuantizedTensor local tensors after training"
    dist_print(f"  {qt_after} params still have QuantizedTensor local tensors.")

    # Optimizer states: master weights and moments should be float32.
    for param in model.parameters():
        state = optimizer.state[param]
        if "master_param" in state:
            assert (
                state["master_param"].dtype == torch.float32
            ), f"Master weight dtype {state['master_param'].dtype}, expected float32"
        assert state["exp_avg"].dtype == torch.float32, "exp_avg should be float32"
        assert state["exp_avg_sq"].dtype == torch.float32, "exp_avg_sq should be float32"

    dist_print("All assertions passed!")
    dist_print("  - Linear weight parameters: QuantizedTensor (FP8) wrapped in DTensor")
    dist_print("  - Optimizer master weights: float32")
    dist_print("  - Optimizer states (exp_avg, exp_avg_sq): float32")

    # ── 8. Distributed checkpoint: save and load ─────────────────────
    # torch.distributed.checkpoint (DCP) saves sharded state -- each rank
    # writes only its local shard. This preserves FP8 compute weights
    # and the full optimizer state (master weights, moments, step count).
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_model_state_dict,
        get_optimizer_state_dict,
    )

    # Use a fixed path so all ranks agree on the checkpoint location.
    checkpoint_dir = "/tmp/te_fsdp2_example_checkpoint"
    dist_print(f"\nSaving distributed checkpoint to {checkpoint_dir} ...")

    # Save the sharded checkpoint. DCP handles DTensor shards natively --
    # each rank writes only its local shard to the filesystem.
    dcp.save(
        {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
        checkpoint_id=checkpoint_dir,
    )
    dist_print("  Checkpoint saved (FP8 weights + optimizer state).")

    # Load the checkpoint back. Pass state dicts with the matching structure;
    # DCP overwrites their values in place from the saved files.
    state_to_load = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
    dcp.load(state_to_load, checkpoint_id=checkpoint_dir)
    model.load_state_dict(state_to_load["model"])
    optimizer.load_state_dict(state_to_load["optimizer"])
    dist_print("  Checkpoint loaded -- FP8 weights and optimizer state restored.")
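
    # Note (added): in a fresh process, resuming would follow the same
    # construction sequence as steps 2-4 above -- rebuild the model on the
    # meta device, apply fully_shard, materialize via reset_parameters() --
    # and only then call dcp.load into freshly built state dicts.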

    # Verify training continues after checkpoint load.
    optimizer.zero_grad(set_to_none=True)
    with te.autocast(enabled=True):
        output = model(x)
    loss = F.mse_loss(output, target)
    loss.backward()
    optimizer.step()
    dist_print(f"  Post-checkpoint training step: loss = {loss.item():.6f}")

    # ── 9. Save full-precision (FP32) model to safetensors ───────────
    # For inference or fine-tuning you typically want FP32 weights, not
    # FP8 compute weights. The optimizer's master weight copies are the
    # authoritative FP32 values (more precise than dequantizing FP8).
    # All ranks must participate in gathering; only rank 0 saves.
    from safetensors.torch import save_file

    full_opts = StateDictOptions(full_state_dict=True, cpu_offload=True)

    full_model_state = get_model_state_dict(model, options=full_opts)
    full_opt_state = get_optimizer_state_dict(model, optimizer, options=full_opts)

    rank = int(os.environ.get("RANK", "0"))
    if rank == 0:
        fp32_state = {}
        opt_param_states = full_opt_state.get("state", {})

        for key, value in full_model_state.items():
            if key in opt_param_states and "master_param" in opt_param_states[key]:
                # Prefer optimizer's FP32 master weight (maintained throughout training).
                fp32_state[key] = opt_param_states[key]["master_param"].float()
            elif isinstance(value, QuantizedTensor):
                # Fallback: dequantize FP8 → FP32 (e.g. if master_weights was off).
                fp32_state[key] = value.dequantize().float()
            else:
                # Non-FP8 params (e.g. LayerNorm weights): cast to FP32.
                fp32_state[key] = value.float()

        save_path = "/tmp/te_fsdp2_example_model_fp32.safetensors"
        save_file(fp32_state, save_path)
        dist_print(f"\nSaved FP32 model ({len(fp32_state)} params) to {save_path}")

        # Quick verification: all saved tensors are float32.
        from safetensors.torch import load_file

        loaded = load_file(save_path)
        for k, v in loaded.items():
            assert v.dtype == torch.float32, f"{k}: expected float32, got {v.dtype}"
        dist_print(f"  Verified: all {len(loaded)} tensors are float32.")
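
        # Possible follow-up (sketch, not executed here): load the file with
        # safetensors.torch.load_file and call load_state_dict on a
        # non-sharded module of the same architecture, assuming the
        # parameter names line up.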

    dist.destroy_process_group()


if __name__ == "__main__":
    main()