Commit 4d89e04 (parent 57b5b60)

Add fused_adam, quantized_model_init, and fsdp2 example

Signed-off-by: Peter St. John <pstjohn@nvidia.com>

6 files changed: 841 additions, 6 deletions

README (47 additions, 0 deletions)
# Quantized Model Initialization Examples

## `main.py` -- Single-GPU with Gradient Accumulation Fusion

Trains a single `TransformerLayer` on synthetic data while combining three
Transformer Engine optimizations:

* **`quantized_model_init`** -- model parameters are stored in FP8, saving
  memory by not keeping a high-precision shadow copy.
* **`FusedAdam` with master weights** -- the optimizer maintains FP32 master
  copies of the weights for stable training updates.
* **Gradient accumulation fusion** -- weight gradients are accumulated directly
  in FP32 via Tensor Cores (`fuse_wgrad_accumulation=True` + `main_grad`); a
  condensed sketch follows the command below.

```bash
python main.py
```
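
A minimal sketch of how these three pieces fit together, condensed from
`main.py` (the `te.Linear` layer here is an illustrative stand-in for the
`TransformerLayer` the script actually builds):

```python
import torch
import transformer_engine.pytorch as te

# FP8 parameters only -- no high-precision shadow copy is kept.
with te.quantized_model_init(enabled=True):
    layer = te.Linear(256, 256, fuse_wgrad_accumulation=True, params_dtype=torch.bfloat16)

# Each parameter needs an FP32 main_grad buffer for the fused wgrad GEMM
# to accumulate into during backward.
for param in layer.parameters():
    param.main_grad = torch.zeros(param.shape, dtype=torch.float32, device=param.device)

# With use_decoupled_grad=True, FusedAdam reads param.decoupled_grad
# instead of param.grad, so point it at the FP32 buffer before step().
optimizer = te.optimizers.FusedAdam(
    layer.parameters(), lr=1e-3, master_weights=True, use_decoupled_grad=True
)
for param in layer.parameters():
    param.decoupled_grad = param.main_grad
```
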
## `fully_shard.py` -- Multi-GPU with FSDP2

Extends the single-GPU example to multi-GPU training using PyTorch-native FSDP2
(`fully_shard`). Demonstrates:

* **`quantized_model_init`** -- same FP8 weight initialization as `main.py`.
* **`fully_shard`** -- PyTorch FSDP2 sharding of each `TransformerLayer`.
* **`save_custom_attrs` / `restore_custom_attrs`** -- preserve custom
  Python-level attributes on `QuantizedTensor` parameters that FSDP2's
  `DTensor` wrapping would otherwise drop (sketched after the command below).
* **`FusedAdam` with master weights** -- FP32 master copies maintained by the
  optimizer, with DTensor-aware state initialization.

```bash
torchrun --nproc-per-node 2 fully_shard.py
```
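
The attribute round-trip around sharding, condensed from `fully_shard.py`
(`save_custom_attrs` / `restore_custom_attrs` are the helpers defined in that
script, and `model` / `mesh` are built as shown there):

```python
# Snapshot Python-level attrs before fully_shard swaps params for DTensors.
attrs = save_custom_attrs(model)

# Shard each TransformerLayer individually, then the root module.
for child in model.children():
    fully_shard(child, mesh=mesh)
fully_shard(model, mesh=mesh)

# Re-attach the saved attrs to the new DTensor parameters.
restore_custom_attrs(model, attrs)
```
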
### Why `fuse_wgrad_accumulation` is not used with FSDP2

`fuse_wgrad_accumulation` writes weight gradients directly into a `main_grad`
buffer during the wgrad GEMM and returns `None` to autograd. This bypasses
FSDP2's reduce-scatter, leaving each rank with an unreduced gradient. Correct
distributed training requires the per-rank gradients to be reduced across ranks,
which FSDP2 handles automatically -- but only for gradients that flow through
autograd.

Megatron-Core's FSDP integration solves this by providing `get_main_grad()`,
which returns a buffer wired into its own reduce-scatter machinery. Vanilla
PyTorch FSDP2 does not yet expose an equivalent hook.
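
To make the failure mode concrete, here is a toy autograd function
(illustrative only, not Transformer Engine's actual implementation) that mimics
the fused-wgrad pattern. It assumes `weight.main_grad` has already been
allocated; the weight gradient is written there out-of-band while `None` is
returned to autograd, so no gradient ever reaches the weight through autograd
and FSDP2's post-backward reduce-scatter has nothing to reduce:

```python
import torch

class FusedWgradMatmul(torch.autograd.Function):
    """Toy fused-wgrad: dW bypasses autograd and lands in weight.main_grad."""

    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        # Accumulate the weight gradient out-of-band, in FP32 ...
        weight.main_grad += (grad_out.t() @ x).float()
        # ... and report no weight gradient to autograd. Hooks attached to
        # the weight's gradient path (e.g. FSDP2's reduce-scatter) never
        # fire, so main_grad stays a purely local, unreduced gradient.
        return grad_out @ weight, None
```
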
fully_shard.py (192 additions, 0 deletions)
```python
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""FSDP2 distributed training with quantized model initialization.

Extends the single-GPU ``main.py`` example to multi-GPU training using
PyTorch-native FSDP2 (``fully_shard``). The script demonstrates:

1. ``quantized_model_init`` -- FP8 weight initialization (same as main.py).
2. ``fully_shard`` -- PyTorch FSDP2 sharding of each TransformerLayer.
3. ``save_custom_attrs`` / ``restore_custom_attrs`` -- Preserve custom
   Python-level attributes on QuantizedTensor parameters that FSDP2's
   DTensor wrapping would otherwise drop.
4. ``FusedAdam`` with FP32 master weights for full-precision training updates.

.. note::
    ``fuse_wgrad_accumulation`` is **not** used here. That feature writes
    weight gradients directly into ``main_grad`` buffers, bypassing the
    autograd gradient flow. FSDP2 requires gradients to go through its
    reduce-scatter, so ``fuse_wgrad_accumulation`` needs Megatron-Core's
    FSDP integration (which provides ``get_main_grad()``).

Usage::

    torchrun --nproc-per-node 2 fully_shard.py
"""

import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor

import transformer_engine.pytorch as te
from transformer_engine.pytorch import QuantizedTensor

# ── Configuration (matches main.py) ──────────────────────────────────
HIDDEN_SIZE = 256
FFN_HIDDEN_SIZE = 1024
NUM_ATTENTION_HEADS = 8
NUM_LAYERS = 3
SEQ_LEN = 32
BATCH_PER_RANK = 2
NUM_STEPS = 5
DTYPE = torch.bfloat16


def dist_print(msg):
    """Print only on rank 0."""
    if int(os.environ.get("RANK", "0")) == 0:
        print(msg)


# ── Save / restore custom attributes across FSDP2 sharding ──────────
# FSDP2's fully_shard replaces parameters with DTensors, which drops any
# custom Python-level attributes. These helpers preserve them.
# (Pattern from tests/pytorch/distributed/run_fsdp2_model.py)


def save_custom_attrs(module):
    """Save custom attributes from all parameters before FSDP2 sharding."""
    custom_attrs = {}
    for name, param in module.named_parameters():
        if isinstance(param, QuantizedTensor):
            ignore_keys = [key for key in param.__dict__.keys() if key.startswith("_")]
        else:
            ignore_keys = []
        attrs = vars(param)
        custom_attrs[name] = {k: v for k, v in attrs.items() if k not in ignore_keys}
    return custom_attrs


def restore_custom_attrs(module, custom_attrs):
    """Restore saved custom attributes after FSDP2 sharding."""
    for name, param in module.named_parameters():
        if name in custom_attrs:
            for attr_name, attr_value in custom_attrs[name].items():
                setattr(param, attr_name, attr_value)


def main():
    # ── 1. Distributed setup ─────────────────────────────────────────
    assert "TORCHELASTIC_RUN_ID" in os.environ, (
        "This script must be launched with torchrun, e.g.:\n"
        "    torchrun --nproc-per-node 2 fully_shard.py"
    )
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{local_rank}")

    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # ── 2. Create model with quantized (FP8) parameters ──────────────
    with te.quantized_model_init(enabled=True):
        model = torch.nn.Sequential(
            *[
                te.TransformerLayer(
                    HIDDEN_SIZE,
                    FFN_HIDDEN_SIZE,
                    NUM_ATTENTION_HEADS,
                    fuse_qkv_params=True,
                    params_dtype=DTYPE,
                    hidden_dropout=0.0,
                    attention_dropout=0.0,
                )
                for _ in range(NUM_LAYERS)
            ]
        )

    # Pre-shard verification: count QuantizedTensor parameters.
    qt_count = sum(1 for _, p in model.named_parameters() if isinstance(p, QuantizedTensor))
    assert qt_count > 0, "No QuantizedTensor parameters found"
    dist_print(f"Found {qt_count} QuantizedTensor (FP8) weight parameters.")

    # ── 3. FSDP2 sharding ────────────────────────────────────────────
    custom_attrs = save_custom_attrs(model)

    mesh = DeviceMesh("cuda", list(range(world_size)))
    for child in model.children():
        fully_shard(child, mesh=mesh)
    fully_shard(model, mesh=mesh)

    restore_custom_attrs(model, custom_attrs)

    # Post-shard verification: parameters are DTensors wrapping QuantizedTensors.
    for name, param in model.named_parameters():
        assert isinstance(param, DTensor), f"{name} is not a DTensor after sharding"
    dist_print("FSDP2 sharding complete. All parameters are DTensors.")

    # ── 4. Optimizer ─────────────────────────────────────────────────
    optimizer = te.optimizers.FusedAdam(
        model.parameters(),
        lr=1e-3,
        master_weights=True,
        master_weight_dtype=torch.float32,
    )
    dist_print("Using FusedAdam with master_weights=True.")

    # ── 5. Training loop ─────────────────────────────────────────────
    x = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device)
    target = torch.randn(SEQ_LEN, BATCH_PER_RANK, HIDDEN_SIZE, dtype=DTYPE, device=device)

    for step in range(NUM_STEPS):
        optimizer.zero_grad(set_to_none=True)

        with te.autocast(enabled=True):
            output = model(x)

        loss = F.mse_loss(output, target)
        loss.backward()
        optimizer.step()
        dist_print(f" Step {step}: loss = {loss.item():.6f}")

    # ── 6. Post-training assertions ──────────────────────────────────
    dist_print("\nVerifying invariants ...")

    qt_after = 0
    for name, param in model.named_parameters():
        assert isinstance(param, DTensor), f"{name} lost DTensor wrapping"
        if isinstance(param._local_tensor, QuantizedTensor):
            qt_after += 1
    assert qt_after > 0, "No QuantizedTensor local tensors after training"
    dist_print(f" {qt_after} params still have QuantizedTensor local tensors.")

    # Optimizer states: master weights and moments should be float32.
    for param in model.parameters():
        state = optimizer.state[param]
        if "master_param" in state:
            assert (
                state["master_param"].dtype == torch.float32
            ), f"Master weight dtype {state['master_param'].dtype}, expected float32"
        assert state["exp_avg"].dtype == torch.float32, "exp_avg should be float32"
        assert state["exp_avg_sq"].dtype == torch.float32, "exp_avg_sq should be float32"

    dist_print("All assertions passed!")
    dist_print(" - Linear weight parameters: QuantizedTensor (FP8) wrapped in DTensor")
    dist_print(" - Optimizer master weights: float32")
    dist_print(" - Optimizer states (exp_avg, exp_avg_sq): float32")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```
main.py (151 additions, 0 deletions)
```python
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Quantized model initialization with FusedAdam and gradient accumulation fusion.

Demonstrates three Transformer Engine features working together:

1. ``quantized_model_init`` -- Initialize a model with low-precision (FP8)
   parameters, avoiding the memory cost of storing both high-precision and
   quantized copies of every weight.

2. ``FusedAdam`` with master weights -- Maintain FP32 master copies of the
   weights inside the optimizer so that the training update retains full
   precision despite the model parameters being FP8.

3. Gradient accumulation fusion -- Use ``fuse_wgrad_accumulation=True``
   together with per-parameter ``main_grad`` buffers so that weight
   gradients are accumulated directly in FP32 via Tensor Cores, avoiding a
   separate FP8-to-FP32 cast kernel.

Usage::

    python main.py
"""

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor

# ── Configuration ──────────────────────────────────────────────────────
HIDDEN_SIZE = 256
FFN_HIDDEN_SIZE = 1024
NUM_ATTENTION_HEADS = 8
SEQ_LEN = 32
BATCH_SIZE = 2
NUM_STEPS = 5
DTYPE = torch.bfloat16


def main():
    # ── 1. Create model with quantized parameters ─────────────────────
    #
    # Inside quantized_model_init, TransformerEngine modules store only the
    # FP8 quantized copy of each parameter (a Float8Tensor), eliminating the
    # memory overhead of a high-precision shadow copy.
    with te.quantized_model_init(enabled=True):
        model = te.TransformerLayer(
            HIDDEN_SIZE,
            FFN_HIDDEN_SIZE,
            NUM_ATTENTION_HEADS,
            fuse_wgrad_accumulation=True,
            fuse_qkv_params=True,  # required for fuse_wgrad_accumulation
            params_dtype=DTYPE,
            hidden_dropout=0.0,  # disable dropout for this synthetic example
            attention_dropout=0.0,
        )

    # Verify that linear-layer weight parameters are quantized.
    # Biases and LayerNorm parameters are *not* quantized.
    quantized_count = 0
    for name, param in model.named_parameters():
        if isinstance(param, QuantizedTensor):
            quantized_count += 1
    assert quantized_count > 0, "No QuantizedTensor parameters found"
    print(f"Found {quantized_count} QuantizedTensor (FP8) weight parameters.")

    # ── 2. Allocate main_grad buffers (FP32) ──────────────────────────
    #
    # fuse_wgrad_accumulation causes weight-gradient GEMMs to write directly
    # into ``param.main_grad`` in FP32 (via Tensor Core accumulation).
    # Non-weight parameters (e.g. LayerNorm) still receive gradients through
    # the normal ``param.grad`` path.
    for param in model.parameters():
        param.main_grad = torch.zeros(param.shape, dtype=torch.float32, device=param.device)

    # ── 3. Optimizer with FP32 master weights ─────────────────────────
    #
    # use_decoupled_grad=True tells FusedAdam to read gradients from
    # ``param.decoupled_grad`` instead of ``param.grad``. This avoids
    # the dtype-mismatch error that would occur when assigning FP32
    # gradients to bfloat16 parameters via ``.grad``.
    optimizer = te.optimizers.FusedAdam(
        model.parameters(),
        lr=1e-3,
        master_weights=True,
        master_weight_dtype=torch.float32,
        use_decoupled_grad=True,
    )

    # ── 4. Training loop ──────────────────────────────────────────────
    #
    # Use a fixed synthetic dataset so that loss decreases over steps.
    x = torch.randn(SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE, device="cuda")
    target = torch.randn(SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE, device="cuda")

    for step in range(NUM_STEPS):
        optimizer.zero_grad(set_to_none=True)
        for param in model.parameters():
            param.main_grad.zero_()

        # Forward pass inside autocast to enable FP8 compute.
        with te.autocast(enabled=True):
            output = model(x)

        loss = torch.nn.functional.mse_loss(output, target)
        loss.backward()

        # Consolidate gradients into main_grad.
        # * Weight params with fuse_wgrad_accumulation: backward already
        #   accumulated the gradient directly into main_grad (FP32).
        # * Other params (e.g. LayerNorm): autograd set param.grad.
        for param in model.parameters():
            if param.grad is not None:
                param.main_grad.copy_(param.grad)
                param.grad = None

        # Expose main_grad as decoupled_grad so FusedAdam can read it.
        for param in model.parameters():
            param.decoupled_grad = param.main_grad

        optimizer.step()
        print(f" Step {step}: loss = {loss.item():.6f}")

    # ── 5. Post-training assertions ───────────────────────────────────
    print("\nVerifying invariants ...")

    # Optimizer states.
    for param in model.parameters():
        state = optimizer.state[param]
        if "master_param" in state:
            master = state["master_param"]
            assert (
                master.dtype == torch.float32
            ), f"Master weight dtype {master.dtype}, expected float32"
        assert state["exp_avg"].dtype == torch.float32, "exp_avg should be float32"
        assert state["exp_avg_sq"].dtype == torch.float32, "exp_avg_sq should be float32"

    # main_grad buffers.
    for param in model.parameters():
        assert param.main_grad.dtype == torch.float32, "main_grad should be float32"

    print("All assertions passed!")
    print(" - Linear weight parameters: QuantizedTensor (FP8)")
    print(" - Optimizer master weights: float32")
    print(" - Optimizer states (exp_avg, exp_avg_sq): float32")
    print(" - Gradient accumulation buffers (main_grad): float32")


if __name__ == "__main__":
    main()
```
