Commit eb8606a

Expand fsdp2 test suite, add example for FusedAdam w/ and w/o fully_shard

Signed-off-by: Peter St. John <pstjohn@nvidia.com>

1 parent: 842b770

11 files changed: 1365 additions & 65 deletions

Lines changed: 266 additions & 0 deletions
@@ -0,0 +1,266 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""FSDP2 distributed training with quantized model initialization.

Extends the single-GPU ``main.py`` example to multi-GPU training using
PyTorch-native FSDP2 (``fully_shard``). The script demonstrates:

1. **Meta-device initialization** -- Model parameters are created on the
   ``meta`` device (zero memory), then FSDP2 sharding is applied, and
   finally ``reset_parameters()`` materializes and quantizes only the
   local shards on each rank's GPU.
2. ``quantized_model_init`` -- Flags the model for FP8 weight initialization
   (actual quantization happens in ``reset_parameters`` after sharding).
3. ``fully_shard`` -- PyTorch FSDP2 sharding of each TransformerLayer.
4. ``FusedAdam`` with FP32 master weights for full-precision training updates.

.. note::
    ``fuse_wgrad_accumulation`` is **not** used here. That feature writes
    weight gradients directly into ``main_grad`` buffers, bypassing the
    autograd gradient flow. FSDP2 requires gradients to go through its
    reduce-scatter, so ``fuse_wgrad_accumulation`` needs Megatron-Core's
    FSDP integration (which provides ``get_main_grad()``).

Usage::

    torchrun --nproc-per-node 2 fully_shard.py
"""

import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor

import transformer_engine.pytorch as te
from transformer_engine.pytorch import QuantizedTensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule

# ── Configuration (matches main.py) ──────────────────────────────────
HIDDEN_SIZE = 256
FFN_HIDDEN_SIZE = 1024
NUM_ATTENTION_HEADS = 8
NUM_LAYERS = 3
SEQ_LEN = 32
BATCH_PER_RANK = 2
NUM_STEPS = 5
DTYPE = torch.bfloat16


def dist_print(msg):
    """Print only on rank 0."""
    if int(os.environ.get("RANK", "0")) == 0:
        print(msg)


def main():
    # ── 1. Distributed setup ─────────────────────────────────────────
    assert "TORCHELASTIC_RUN_ID" in os.environ, (
        "This script must be launched with torchrun, e.g.:\n"
        "  torchrun --nproc-per-node 2 fully_shard.py"
    )
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{local_rank}")

    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # ── 2. Create model on meta device (zero memory) ────────────────
    # quantized_model_init sets the flag for FP8 weight initialization,
    # but with device="meta" no actual memory is allocated yet.
    with te.quantized_model_init(enabled=True):
        model = torch.nn.Sequential(
            *[
                te.TransformerLayer(
                    HIDDEN_SIZE,
                    FFN_HIDDEN_SIZE,
                    NUM_ATTENTION_HEADS,
                    fuse_qkv_params=True,
                    params_dtype=DTYPE,
                    hidden_dropout=0.0,
                    attention_dropout=0.0,
                    device="meta",
                )
                for _ in range(NUM_LAYERS)
            ]
        )

    # Verify all parameters are on meta device (no GPU memory used).
    for name, param in model.named_parameters():
        assert param.device == torch.device("meta"), f"{name} is not on meta device"
    dist_print("Model created on meta device (zero GPU memory).")

    # ── 3. FSDP2 sharding ────────────────────────────────────────────
    # Apply sharding to the meta-device model. FSDP2 wraps parameters
    # as DTensors but no GPU memory is allocated yet.
    mesh = DeviceMesh("cuda", list(range(world_size)))
    for child in model.children():
        fully_shard(child, mesh=mesh)
    fully_shard(model, mesh=mesh)
    dist_print("FSDP2 sharding applied to meta-device model.")

    # ── 4. Materialize parameters on GPU ──────────────────────────────
    # reset_parameters() on each TE module materializes the local shard
    # on CUDA, applies weight initialization, and quantizes to FP8.
    for module in model.modules():
        if isinstance(module, TransformerEngineBaseModule):
            module.reset_parameters()

    # Post-materialization verification.
    for name, param in model.named_parameters():
        assert isinstance(param, DTensor), f"{name} is not a DTensor after sharding"
    qt_count = sum(
        1
        for _, p in model.named_parameters()
        if isinstance(p, DTensor) and isinstance(p._local_tensor, QuantizedTensor)
    )
    assert qt_count > 0, "No QuantizedTensor local tensors after materialization"
    dist_print(
        f"Parameters materialized: {qt_count} FP8 (QuantizedTensor) weight params "
        "wrapped in DTensors."
    )
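    # Illustrative peek (not in the original example): a DTensor reports
    # the global shape, while its local tensor holds only this rank's
    # slice (roughly 1/world_size of dim 0 under fully_shard).
    w0 = next(model.parameters())
    dist_print(f" e.g. global shape {tuple(w0.shape)}, local shard {tuple(w0._local_tensor.shape)}")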

    # ── 5. Optimizer ─────────────────────────────────────────────────
    optimizer = te.optimizers.FusedAdam(
        model.parameters(),
        lr=1e-3,
        master_weights=True,
        master_weight_dtype=torch.float32,
    )
    dist_print("Using FusedAdam with master_weights=True.")

    # ── 6. Training loop ─────────────────────────────────────────────
    x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device)
    target = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device)

    for step in range(NUM_STEPS):
        optimizer.zero_grad(set_to_none=True)

        with te.autocast(enabled=True):
            output = model(x)

        loss = F.mse_loss(output, target)
        loss.backward()
        optimizer.step()
        dist_print(f" Step {step}: loss = {loss.item():.6f}")

    # ── 7. Post-training assertions ──────────────────────────────────
    dist_print("\nVerifying invariants ...")

    qt_after = 0
    for name, param in model.named_parameters():
        assert isinstance(param, DTensor), f"{name} lost DTensor wrapping"
        if isinstance(param._local_tensor, QuantizedTensor):
            qt_after += 1
    assert qt_after > 0, "No QuantizedTensor local tensors after training"
    dist_print(f" {qt_after} params still have QuantizedTensor local tensors.")

    # Optimizer states: master weights and moments should be float32.
    for param in model.parameters():
        state = optimizer.state[param]
        if "master_param" in state:
            assert (
                state["master_param"].dtype == torch.float32
            ), f"Master weight dtype {state['master_param'].dtype}, expected float32"
        assert state["exp_avg"].dtype == torch.float32, "exp_avg should be float32"
        assert state["exp_avg_sq"].dtype == torch.float32, "exp_avg_sq should be float32"

    dist_print("All assertions passed!")
    dist_print(" - Linear weight parameters: QuantizedTensor (FP8) wrapped in DTensor")
    dist_print(" - Optimizer master weights: float32")
    dist_print(" - Optimizer states (exp_avg, exp_avg_sq): float32")

    # ── 8. Distributed checkpoint: save and load ─────────────────────
    # torch.distributed.checkpoint (DCP) saves sharded state — each rank
    # writes only its local shard. This preserves FP8 compute weights
    # and the full optimizer state (master weights, moments, step count).
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_model_state_dict,
        get_optimizer_state_dict,
    )

    # Use a fixed path so all ranks agree on the checkpoint location.
    checkpoint_dir = "/tmp/te_fsdp2_example_checkpoint"
    dist_print(f"\nSaving distributed checkpoint to {checkpoint_dir} ...")

    # Save sharded checkpoint. DCP handles DTensor shards natively —
    # each rank writes only its local shard to the filesystem.
    dcp.save(
        {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
        checkpoint_id=checkpoint_dir,
    )
    dist_print(" Checkpoint saved (FP8 weights + optimizer state).")

    # Load checkpoint back. Provide empty state dict containers with the
    # same structure; DCP fills them from the saved files.
    state_to_load = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
    dcp.load(state_to_load, checkpoint_id=checkpoint_dir)
    model.load_state_dict(state_to_load["model"])
    optimizer.load_state_dict(state_to_load["optimizer"])
    dist_print(" Checkpoint loaded — FP8 weights and optimizer state restored.")

    # Verify training continues after checkpoint load.
    optimizer.zero_grad(set_to_none=True)
    with te.autocast(enabled=True):
        output = model(x)
    loss = F.mse_loss(output, target)
    loss.backward()
    optimizer.step()
    dist_print(f" Post-checkpoint training step: loss = {loss.item():.6f}")

    # ── 9. Save full-precision (FP32) model to safetensors ───────────
    # For inference or fine-tuning you typically want FP32 weights, not
    # FP8 compute weights. The optimizer's master weight copies are the
    # authoritative FP32 values (more precise than dequantizing FP8).
    # All ranks must participate in gathering; only rank 0 saves.
    from safetensors.torch import save_file

    full_opts = StateDictOptions(full_state_dict=True, cpu_offload=True)

    full_model_state = get_model_state_dict(model, options=full_opts)
    full_opt_state = get_optimizer_state_dict(model, optimizer, options=full_opts)

    rank = int(os.environ.get("RANK", "0"))
    if rank == 0:
        fp32_state = {}
        opt_param_states = full_opt_state.get("state", {})

        for key, value in full_model_state.items():
            if key in opt_param_states and "master_param" in opt_param_states[key]:
                # Prefer optimizer's FP32 master weight (maintained throughout training).
                fp32_state[key] = opt_param_states[key]["master_param"].float()
            elif isinstance(value, QuantizedTensor):
                # Fallback: dequantize FP8 → FP32 (e.g. if master_weights was off).
                fp32_state[key] = value.dequantize().float()
            else:
                # Non-FP8 params (e.g. LayerNorm weights): cast to FP32.
                fp32_state[key] = value.float()

        save_path = "/tmp/te_fsdp2_example_model_fp32.safetensors"
        save_file(fp32_state, save_path)
        dist_print(f"\nSaved FP32 model ({len(fp32_state)} params) to {save_path}")

        # Quick verification: all saved tensors are float32.
        from safetensors.torch import load_file

        loaded = load_file(save_path)
        for k, v in loaded.items():
            assert v.dtype == torch.float32, f"{k}: expected float32, got {v.dtype}"
        dist_print(f" Verified: all {len(loaded)} tensors are float32.")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
