Skip to content

Commit c84d170

Browse files
shjwudp and ksivamani authored
Support FP8 primary weight in FSDP training (NVIDIA#1630)
Support fp8 primary weight in fsdp training Signed-off-by: jianbinc <shjwudp@gmail.com> Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
1 parent a3ba4df commit c84d170

2 files changed

Lines changed: 318 additions & 15 deletions

File tree

tests/pytorch/distributed/run_cast_master_weights_to_fp8.py

Lines changed: 279 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@
2121
)
2222
import transformer_engine.pytorch as te
2323
from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8
24-
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
24+
from transformer_engine.pytorch.tensor.float8_tensor import (
25+
Float8Tensor,
26+
Float8CurrentScalingQuantizer,
27+
)
28+
from transformer_engine.pytorch.tensor.utils import replace_raw_data
2529

2630

2731
def _get_raw_data(quantized_tensor):
@@ -228,6 +232,279 @@ def step(self):
228232
weight.data.copy_(master_weight)
229233

230234

235+
class MiniFSDP:
    """Minimal FSDP-style sharded optimizer used for testing.

    Flattens all weights into one padded buffer, shards that buffer evenly
    across the data-parallel group, keeps FP32 master weights only for the
    local shard, and rebuilds the full flattened weights with an all-gather
    after every optimization step.
    """

    def __init__(self, weights, lr, dp_group):
        """
        Args:
            weights: List of model parameters. Either all Float8Tensor or all
                plain torch.Tensor (the first element decides which path is taken).
            lr: Learning rate for the plain SGD update in step().
            dp_group: Data-parallel process group used for sharding and collectives.
        """
        rank = dist.get_rank(dp_group)
        world_size = dist.get_world_size(dp_group)

        self.weights = weights
        self.lr = lr
        self.dp_group = dp_group

        # Flatten the weights (raw uint8 payload for Float8Tensor) and pad to
        # align with the world size.
        # Fix: the original built raw_data_list with a comprehension and then
        # immediately rebuilt it with this if/else, discarding the first result.
        if isinstance(weights[0], Float8Tensor):
            raw_data_list = [_get_raw_data(w).view(-1) for w in weights]
        else:
            raw_data_list = [w.view(-1) for w in weights]
        self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list)

        # Split flattened weights into equal per-rank shards.
        self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank]
        self.local_main_grad_shard = torch.zeros_like(self.local_weight_shard)
        shard_size = self.flatten_weight.size(0) // world_size

        # Map each original tensor to its [start, end) range in the flattened buffer.
        tensor_indices = []
        cumulative_length = 0
        for tensor in raw_data_list:
            length = tensor.size(0)
            tensor_indices.append((cumulative_length, cumulative_length + length))
            cumulative_length += length

        # Build per-tensor index mappings into this rank's shard:
        #   weight_indices -> offsets within the tensor
        #   shard_indices  -> offsets within the local shard
        # (None, None) means the tensor does not overlap this rank's shard.
        self.weight_indices = []
        self.shard_indices = []
        for idx, (start, end) in enumerate(tensor_indices):
            shard_start = rank * shard_size
            shard_end = shard_start + shard_size
            # Clamp to the unpadded length so padding never counts as real data.
            adjusted_end = min(shard_end, original_length)

            if start <= adjusted_end and end >= shard_start:
                start_idx = max(start, shard_start)
                end_idx = min(end, adjusted_end)
                self.weight_indices.append((start_idx - start, end_idx - start))
                self.shard_indices.append((start_idx - shard_start, end_idx - shard_start))
            else:
                self.weight_indices.append((None, None))
                self.shard_indices.append((None, None))

            # Re-point each model weight at its slice of the flattened buffer so
            # the all-gather in step() updates the model weights in place.
            if isinstance(weights[idx], Float8Tensor):
                replace_raw_data(
                    weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
                )
            else:
                weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape)

        # Initialize local model weight views and high-precision master weights
        # for the portion of each tensor owned by this rank.
        self.local_weights = []
        self.master_weights = []
        for i, weight in enumerate(self.weights):
            weight_start, weight_end = self.weight_indices[i]
            shard_start, shard_end = self.shard_indices[i]
            if shard_start is not None and shard_end is not None:
                local_weight_shard = self.local_weight_shard[shard_start:shard_end]
                self.local_weights.append(local_weight_shard)

                if isinstance(weight, QuantizedTensor):
                    # Master weights come from the preserved high-precision init
                    # values so FP8 and BF16 runs start from identical state.
                    high_precision_init_val = weight.get_high_precision_init_val().view(-1)
                    master_weight_shard = high_precision_init_val.to(weight.device).float()[
                        weight_start:weight_end
                    ]
                else:
                    master_weight_shard = weight.detach().view(-1).float()[weight_start:weight_end]
                self.master_weights.append(master_weight_shard)
            else:
                self.local_weights.append(None)
                self.master_weights.append(None)
            # Every weight needs a main_grad buffer — zero_grad()/step() iterate
            # over all weights regardless of shard ownership.
            weight.main_grad = torch.zeros_like(weight, dtype=torch.float32, device="cuda")

    def _flatten_tensors_with_pad(self, tensors):
        """
        Flatten the list of tensors and pad them to align with the world size.

        Args:
            tensors (list): List of 1-D tensors to flatten.

        Returns:
            tuple: Flattened tensor and its original length before padding.
        """
        world_size = dist.get_world_size(self.dp_group)

        flatten_tensor = torch.cat(tensors)
        original_length = flatten_tensor.size(0)

        padding_needed = (world_size - original_length % world_size) % world_size
        if padding_needed > 0:
            # Fix: allocate padding on the same device as the weights — a
            # device-less torch.zeros would be CPU and break torch.cat for
            # CUDA-resident weights.
            flatten_tensor = torch.cat(
                [
                    flatten_tensor,
                    torch.zeros(
                        padding_needed,
                        dtype=flatten_tensor.dtype,
                        device=flatten_tensor.device,
                    ),
                ]
            )

        return flatten_tensor, original_length

    def zero_grad(self):
        """Clear autograd grads and zero the persistent main_grad buffers."""
        for weight in self.weights:
            weight.grad = None
            weight.main_grad.zero_()

    def step(self):
        """
        Perform an optimization step for the distributed sharded model.

        This method includes:
        1. Gradient reduce-scatter: Synchronize gradients across all processes.
        2. Master weight update: Update high-precision master weights using local gradients.
        3. Precision casting: Cast updated master weights to FP8 or BF16 precision.
        4. Weight synchronization: All-gather updated weights across all processes.

        Returns:
            None
        """
        # Step 1: Reduce-scatter the gradients so each rank owns the summed
        # gradient for its shard only.
        main_grad_buffer, _ = self._flatten_tensors_with_pad(
            [weight.main_grad.view(-1) for weight in self.weights]
        )
        main_grad_buffer = main_grad_buffer.to(self.local_main_grad_shard.dtype)
        dist.reduce_scatter_tensor(
            self.local_main_grad_shard, main_grad_buffer, group=self.dp_group
        )

        # Step 2: Update the FP32 master weights with plain SGD.
        for weight, master_weight, (shard_start, shard_end) in zip(
            self.weights, self.master_weights, self.shard_indices
        ):
            if master_weight is None:
                continue

            # Extract the local gradient shard for this weight.
            grad = self.local_main_grad_shard[shard_start:shard_end]

            # In-place update so the shard views stay aliased correctly.
            master_weight -= grad * self.lr

        # Step 3: Cast master weights back into the model precision.
        if isinstance(self.weights[0], Float8Tensor):
            local_weights = []
            for model_weight, local_weight in zip(self.weights, self.local_weights):
                if local_weight is None:
                    local_weights.append(None)
                    continue

                quantizer = model_weight._get_quantizer()
                if isinstance(quantizer, Float8CurrentScalingQuantizer):
                    # Current-scaling needs a proper Float8Tensor wrapper around
                    # the raw shard so amax/scale bookkeeping works.
                    local_weight = quantizer.create_tensor_from_data(
                        local_weight.view(-1),
                        model_weight.dtype,
                    )
                local_weights.append(local_weight)

            cast_master_weights_to_fp8(
                self.weights,
                self.master_weights,
                [idx[0] for idx in self.weight_indices],
                self.dp_group,
                local_weights,
            )
        else:
            for weight, master_weight in zip(self.local_weights, self.master_weights):
                if master_weight is None:
                    continue

                # Copy updated master weights into the BF16 local shard views.
                weight.data.copy_(master_weight)

        # Step 4: All-gather updated shards; model weights alias flatten_weight,
        # so this refreshes every parameter in place.
        dist.all_gather_into_tensor(
            self.flatten_weight, self.local_weight_shard, group=self.dp_group
        )
417+
def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
    """Validate FP8 primary-weight training under MiniFSDP.

    Trains an FP8 model and a BF16 reference model from identical initial
    weights with identical per-rank inputs/targets, and asserts the losses
    match bit-for-bit at every step.

    Args:
        quantization: Quantization scheme name, or None to disable FP8.
        dp_group: Data-parallel process group to shard over.
    """
    rank = dist.get_rank(dp_group)
    world_size = dist.get_world_size(dp_group)

    # Configuration constants
    NUM_STEPS = 100
    SEED = 12345

    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)

    # One single-rank group per rank so FP8 amax reductions stay rank-local.
    mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
    mock_group = mock_groups[rank]

    linear_kwargs = {
        "params_dtype": torch.bfloat16,
        "bias": False,
        "fuse_wgrad_accumulation": False,
    }

    # Create model with FP8 weights. preserve_high_precision_init_val lets
    # MiniFSDP seed its master weights from the exact initialization values.
    with te.fp8.fp8_model_init(
        enabled=quantization is not None,
        recipe=quantization_recipe(quantization),
        preserve_high_precision_init_val=True,
    ):
        model_fp8 = nn.Sequential(
            te.Linear(128, 256, **linear_kwargs),
            te.Linear(256, 256 * 3, **linear_kwargs),
            te.Linear(256 * 3, 128, **linear_kwargs),
        )

    # Create model with BF16 weights
    model = nn.Sequential(
        te.Linear(128, 256, **linear_kwargs),
        te.Linear(256, 256 * 3, **linear_kwargs),
        te.Linear(256 * 3, 128, **linear_kwargs),
    )

    # Make sure the BF16 model and FP8 model have the same initial weights
    for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
        high_precision_init_val = w_fp8.get_high_precision_init_val()
        w.data.copy_(high_precision_init_val)

    optimizer_fp8 = MiniFSDP([w for w in model_fp8.parameters()], 10.0, dp_group)
    optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group)

    # Fix: iterate NUM_STEPS instead of a hard-coded 100 so the loop count and
    # the success message below cannot drift apart.
    for _ in range(NUM_STEPS):
        optimizer_fp8.zero_grad()
        optimizer.zero_grad()

        # Generate world_size inputs from the shared RNG stream, then choose
        # by rank so every rank sees a different input deterministically.
        inputs = [
            torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
        ]
        x = inputs[rank]

        with te.fp8.fp8_autocast(
            enabled=quantization is not None,
            fp8_recipe=quantization_recipe(quantization),
            fp8_group=mock_group,
        ):
            y_fp8 = model_fp8(x)

        # Consistency: use the same te.fp8.fp8_autocast spelling as the FP8
        # model above (the original mixed it with the te.fp8_autocast alias).
        with te.fp8.fp8_autocast(
            enabled=quantization is not None,
            fp8_recipe=quantization_recipe(quantization),
            fp8_group=mock_group,
        ):
            y = model(x)

        targets = [torch.randn_like(y) for _ in range(world_size)]
        # Choose based on rank to make sure the targets of different ranks are different.
        target = targets[rank]
        loss_fp8 = nn.MSELoss()(y_fp8, target)
        loss = nn.MSELoss()(y, target)

        loss_fp8.backward()
        loss.backward()

        optimizer_fp8.step()
        optimizer.step()

        # Both models must stay numerically identical — zero tolerance.
        torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)

    print(
        f"✅ Successfully validated FSDP {NUM_STEPS} training steps with"
        f" {quantization} quantization"
    )
506+
507+
231508
def _test_zero_1(dp_group):
232509
"""Make sure the implementation of zero-1 optimizer is correct"""
233510
rank = dist.get_rank(dp_group)
@@ -389,6 +666,7 @@ def main(argv=None, namespace=None):
389666
dp_group = dist.new_group(backend="nccl")
390667
_test_zero_1(dp_group)
391668
_test_cast_master_weights_to_fp8(args.quantization, dp_group)
669+
_test_fsdp_cast_master_weights_to_fp8(args.quantization, dp_group)
392670

393671
dist.destroy_process_group()
394672
return 0

0 commit comments

Comments
 (0)