# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""FSDP2 distributed training with quantized model initialization.

Extends the single-GPU ``main.py`` example to multi-GPU training using
PyTorch-native FSDP2 (``fully_shard``). The script demonstrates:

1. ``quantized_model_init`` -- FP8 weight initialization (same as main.py).
2. ``fully_shard`` -- PyTorch FSDP2 sharding of each TransformerLayer.
3. ``save_custom_attrs`` / ``restore_custom_attrs`` -- Preserve custom
   Python-level attributes on QuantizedTensor parameters that FSDP2's
   DTensor wrapping would otherwise drop.
4. ``FusedAdam`` with FP32 master weights for full-precision optimizer updates.

.. note::
    ``fuse_wgrad_accumulation`` is **not** used here. That feature writes
    weight gradients directly into ``main_grad`` buffers, bypassing the
    autograd gradient flow. FSDP2 requires gradients to pass through its
    reduce-scatter, so ``fuse_wgrad_accumulation`` is only usable with
    Megatron-Core's FSDP integration, which provides ``get_main_grad()``.

Usage::

    torchrun --nproc-per-node 2 fully_shard.py
"""

import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor

import transformer_engine.pytorch as te
from transformer_engine.pytorch import QuantizedTensor

# ── Configuration (matches main.py) ──────────────────────────────────
HIDDEN_SIZE = 256
FFN_HIDDEN_SIZE = 1024
NUM_ATTENTION_HEADS = 8
NUM_LAYERS = 3
SEQ_LEN = 32
BATCH_PER_RANK = 2
NUM_STEPS = 5
DTYPE = torch.bfloat16


def dist_print(msg):
    """Print only on rank 0."""
    if int(os.environ.get("RANK", "0")) == 0:
        print(msg)


# ── Save / restore custom attributes across FSDP2 sharding ──────────
# FSDP2's fully_shard replaces parameters with DTensors, which drops any
# custom Python-level attributes. These helpers preserve them.
# (Pattern from tests/pytorch/distributed/run_fsdp2_model.py)


def save_custom_attrs(module):
    """Save custom attributes from all parameters before FSDP2 sharding."""
    custom_attrs = {}
    for name, param in module.named_parameters():
        if isinstance(param, QuantizedTensor):
            # Underscore-prefixed keys are quantization internals; skip them
            # and keep only the plain Python attributes.
            ignore_keys = [key for key in param.__dict__.keys() if key.startswith("_")]
        else:
            ignore_keys = []
        attrs = vars(param)
        custom_attrs[name] = {k: v for k, v in attrs.items() if k not in ignore_keys}
    return custom_attrs


def restore_custom_attrs(module, custom_attrs):
    """Restore saved custom attributes after FSDP2 sharding."""
    for name, param in module.named_parameters():
        if name in custom_attrs:
            for attr_name, attr_value in custom_attrs[name].items():
                setattr(param, attr_name, attr_value)
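
# A minimal sketch of the intended round-trip, using a hypothetical ``layer``
# module that holds QuantizedTensor parameters:
#
#   attrs = save_custom_attrs(layer)    # snapshot plain Python attributes
#   fully_shard(layer)                  # params become DTensors; attrs dropped
#   restore_custom_attrs(layer, attrs)  # re-attach them to the new DTensors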


def main():
    # ── 1. Distributed setup ─────────────────────────────────────────
    assert "TORCHELASTIC_RUN_ID" in os.environ, (
        "This script must be launched with torchrun, e.g.:\n"
        "    torchrun --nproc-per-node 2 fully_shard.py"
    )
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{local_rank}")

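    # Same seed on every rank so all ranks build identical weights pre-sharding.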
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # ── 2. Create model with quantized (FP8) parameters ──────────────
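    # Weights are created directly in FP8 storage (QuantizedTensor) rather than
    # allocated in high precision and quantized afterwards.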
    with te.quantized_model_init(enabled=True):
        model = torch.nn.Sequential(
            *[
                te.TransformerLayer(
                    HIDDEN_SIZE,
                    FFN_HIDDEN_SIZE,
                    NUM_ATTENTION_HEADS,
                    fuse_qkv_params=True,
                    params_dtype=DTYPE,
                    hidden_dropout=0.0,
                    attention_dropout=0.0,
                )
                for _ in range(NUM_LAYERS)
            ]
        )

    # Pre-shard verification: count QuantizedTensor parameters.
    qt_count = sum(1 for _, p in model.named_parameters() if isinstance(p, QuantizedTensor))
    assert qt_count > 0, "No QuantizedTensor parameters found"
    dist_print(f"Found {qt_count} QuantizedTensor (FP8) weight parameters.")

    # ── 3. FSDP2 sharding ────────────────────────────────────────────
    custom_attrs = save_custom_attrs(model)

    mesh = DeviceMesh("cuda", list(range(world_size)))
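    # Shard each TransformerLayer as its own FSDP2 unit, then the root module;
    # per-layer units let FSDP2 all-gather one layer at a time instead of the
    # whole model at once.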
    for child in model.children():
        fully_shard(child, mesh=mesh)
    fully_shard(model, mesh=mesh)

    restore_custom_attrs(model, custom_attrs)

    # Post-shard verification: parameters are DTensors wrapping QuantizedTensors.
    for name, param in model.named_parameters():
        assert isinstance(param, DTensor), f"{name} is not a DTensor after sharding"
    dist_print("FSDP2 sharding complete. All parameters are DTensors.")

    # ── 4. Optimizer ─────────────────────────────────────────────────
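    # FusedAdam keeps an FP32 master copy of each parameter: each step updates
    # the master copy in full precision, then writes the result back into the
    # low-precision model weights.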
    optimizer = te.optimizers.FusedAdam(
        model.parameters(),
        lr=1e-3,
        master_weights=True,
        master_weight_dtype=torch.float32,
    )
    dist_print("Using FusedAdam with master_weights=True.")

    # ── 5. Training loop ─────────────────────────────────────────────
    x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device)
    target = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device)

    for step in range(NUM_STEPS):
        optimizer.zero_grad(set_to_none=True)

        with te.autocast(enabled=True):
            output = model(x)

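        # The loss is computed outside te.autocast, i.e. in plain bf16 rather
        # than inside the FP8 compute region.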
        loss = F.mse_loss(output, target)
        loss.backward()
        optimizer.step()
        dist_print(f"  Step {step}: loss = {loss.item():.6f}")

    # ── 6. Post-training assertions ──────────────────────────────────
    dist_print("\nVerifying invariants ...")

    qt_after = 0
    for name, param in model.named_parameters():
        assert isinstance(param, DTensor), f"{name} lost DTensor wrapping"
        if isinstance(param._local_tensor, QuantizedTensor):
            qt_after += 1
    assert qt_after > 0, "No QuantizedTensor local tensors after training"
    dist_print(f"  {qt_after} params still have QuantizedTensor local tensors.")
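
    # To inspect weight values numerically, one could dequantize a local shard,
    # e.g. (assuming the QuantizedTensor API exposes ``dequantize()``):
    #
    #   hp = param._local_tensor.dequantize()  # plain high-precision tensor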

    # Optimizer states: master weights and moments should be float32.
    for param in model.parameters():
        state = optimizer.state[param]
        if "master_param" in state:
            assert (
                state["master_param"].dtype == torch.float32
            ), f"Master weight dtype {state['master_param'].dtype}, expected float32"
        assert state["exp_avg"].dtype == torch.float32, "exp_avg should be float32"
        assert state["exp_avg_sq"].dtype == torch.float32, "exp_avg_sq should be float32"

    dist_print("All assertions passed!")
    dist_print("  - Linear weight parameters: QuantizedTensor (FP8) wrapped in DTensor")
    dist_print("  - Optimizer master weights: float32")
    dist_print("  - Optimizer states (exp_avg, exp_avg_sq): float32")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()