|
21 | 21 | ) |
22 | 22 | import transformer_engine.pytorch as te |
23 | 23 | from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8 |
24 | | -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor |
| 24 | +from transformer_engine.pytorch.tensor.float8_tensor import ( |
| 25 | + Float8Tensor, |
| 26 | + Float8CurrentScalingQuantizer, |
| 27 | +) |
| 28 | +from transformer_engine.pytorch.tensor.utils import replace_raw_data |
25 | 29 |
|
26 | 30 |
|
27 | 31 | def _get_raw_data(quantized_tensor): |
@@ -228,6 +232,279 @@ def step(self): |
228 | 232 | weight.data.copy_(master_weight) |
229 | 233 |
|
230 | 234 |
|
class MiniFSDP:
    """Minimal FSDP-style sharded optimizer used for testing.

    All weights are flattened into a single contiguous buffer, the buffer is
    sharded evenly across the data-parallel group, and each rank keeps an
    FP32 master-weight shard for its slice. ``step()`` reduce-scatters the
    gradients, applies plain SGD on the master shards, casts the result back
    to model precision (FP8 or BF16), and all-gathers the updated weights.
    """

    def __init__(self, weights, lr, dp_group):
        """
        Args:
            weights (list): Model parameters; either all ``Float8Tensor`` or
                all high-precision tensors (the first element decides).
            lr (float): Learning rate for the SGD update.
            dp_group: torch.distributed process group to shard across.
        """
        rank = dist.get_rank(dp_group)
        world_size = dist.get_world_size(dp_group)

        self.weights = weights
        self.lr = lr
        self.dp_group = dp_group

        # Flatten the weights (raw uint8 storage for FP8 tensors) and pad so
        # the total length divides evenly by the world size.
        # NOTE: the original code built raw_data_list twice — a per-element
        # isinstance comprehension immediately overwritten by this if/else —
        # so the dead first computation has been removed.
        if isinstance(weights[0], Float8Tensor):
            raw_data_list = [_get_raw_data(w).view(-1) for w in weights]
        else:
            raw_data_list = [w.view(-1) for w in weights]
        self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list)

        # Split the flattened buffer into equal per-rank shards.
        self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank]
        self.local_main_grad_shard = torch.zeros_like(self.local_weight_shard)
        shard_size = self.flatten_weight.size(0) // world_size

        # Map each original tensor to its [start, end) span in the flat buffer.
        tensor_indices = []
        cumulative_length = 0
        for tensor in raw_data_list:
            length = tensor.size(0)
            tensor_indices.append((cumulative_length, cumulative_length + length))
            cumulative_length += length

        # For each weight, record the overlap between its span and this
        # rank's shard, in both weight-local and shard-local coordinates.
        self.weight_indices = []
        self.shard_indices = []
        for idx, (start, end) in enumerate(tensor_indices):
            shard_start = rank * shard_size
            shard_end = shard_start + shard_size
            # Exclude the zero padding appended at the end of the buffer.
            adjusted_end = min(shard_end, original_length)

            # NOTE(review): the inclusive comparison means a tensor that only
            # touches the shard boundary gets an empty (zero-length) slice
            # instead of (None, None) — presumably intentional so it still
            # participates in the collective cast below; confirm.
            if start <= adjusted_end and end >= shard_start:
                start_idx = max(start, shard_start)
                end_idx = min(end, adjusted_end)
                self.weight_indices.append((start_idx - start, end_idx - start))
                self.shard_indices.append((start_idx - shard_start, end_idx - shard_start))
            else:
                self.weight_indices.append((None, None))
                self.shard_indices.append((None, None))

            # Re-point each model weight at its view of the flat buffer so the
            # all-gather in step() updates the model parameters in place.
            if isinstance(weights[idx], Float8Tensor):
                replace_raw_data(
                    weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
                )
            else:
                weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape)

        # Initialize local model-weight views and FP32 master-weight shards.
        self.local_weights = []
        self.master_weights = []
        for i, weight in enumerate(self.weights):
            weight_start, weight_end = self.weight_indices[i]
            shard_start, shard_end = self.shard_indices[i]
            if shard_start is not None and shard_end is not None:
                local_weight_shard = self.local_weight_shard[shard_start:shard_end]
                self.local_weights.append(local_weight_shard)

                # For quantized weights the master copy comes from the
                # preserved high-precision init value; otherwise from the
                # weight itself.
                if isinstance(weight, QuantizedTensor):
                    high_precision_init_val = weight.get_high_precision_init_val().view(-1)
                    master_weight_shard = high_precision_init_val.to(weight.device).float()[
                        weight_start:weight_end
                    ]
                else:
                    master_weight_shard = weight.detach().view(-1).float()[weight_start:weight_end]
                self.master_weights.append(master_weight_shard)
            else:
                self.local_weights.append(None)
                self.master_weights.append(None)
            # FP32 gradient-accumulation buffer, one full-size buffer per
            # weight (plain attribute assignment instead of setattr with a
            # constant name).
            weight.main_grad = torch.zeros_like(weight, dtype=torch.float32, device="cuda")

    def _flatten_tensors_with_pad(self, tensors):
        """
        Flatten a list of 1-D tensors and zero-pad to align with world size.

        Args:
            tensors (list): List of 1-D tensors to concatenate.

        Returns:
            tuple: Flattened tensor and its original length before padding.
        """
        world_size = dist.get_world_size(self.dp_group)

        flatten_tensor = torch.cat(tensors)
        original_length = flatten_tensor.size(0)

        padding_needed = (world_size - original_length % world_size) % world_size
        if padding_needed > 0:
            # Allocate padding on the same device/dtype as the data; the
            # original allocated on the default (CPU) device, which makes
            # torch.cat fail for CUDA weights.
            padding = torch.zeros(
                padding_needed, dtype=flatten_tensor.dtype, device=flatten_tensor.device
            )
            flatten_tensor = torch.cat([flatten_tensor, padding])

        return flatten_tensor, original_length

    def zero_grad(self):
        """Clear autograd grads and zero the FP32 accumulation buffers."""
        for weight in self.weights:
            weight.grad = None
            weight.main_grad.zero_()

    def step(self):
        """
        Perform an optimization step for the distributed sharded model.

        This method includes:
        1. Gradient reduce-scatter: Synchronize gradients across all processes.
        2. Master weight update: Update high-precision master weights using local gradients.
        3. Precision casting: Cast updated master weights to FP8 or BF16 precision.
        4. Weight synchronization: All-gather updated weights across all processes.

        Returns:
            None
        """
        # Step 1: Reduce-scatter the gradients
        main_grad_buffer, _ = self._flatten_tensors_with_pad(
            [weight.main_grad.view(-1) for weight in self.weights]
        )
        main_grad_buffer = main_grad_buffer.to(self.local_main_grad_shard.dtype)
        dist.reduce_scatter_tensor(
            self.local_main_grad_shard, main_grad_buffer, group=self.dp_group
        )

        # Step 2: Update the master weights
        for weight, master_weight, (shard_start, shard_end) in zip(
            self.weights, self.master_weights, self.shard_indices
        ):
            if master_weight is None:
                continue

            # Extract the local gradient shard for this weight
            grad = self.local_main_grad_shard[shard_start:shard_end]

            # Update the master weight using gradient descent
            master_weight -= grad * self.lr

        # Step 3: Cast master weights to FP8 or BF16 precision
        if isinstance(self.weights[0], Float8Tensor):
            local_weights = []
            for model_weight, local_weight in zip(self.weights, self.local_weights):
                if local_weight is None:
                    local_weights.append(None)
                    continue

                # Current-scaling quantizers need the local shard wrapped in
                # a quantized tensor so amax/scale are recomputed on cast.
                quantizer = model_weight._get_quantizer()
                if isinstance(quantizer, Float8CurrentScalingQuantizer):
                    local_weight = quantizer.create_tensor_from_data(
                        local_weight.view(-1),
                        model_weight.dtype,
                    )
                local_weights.append(local_weight)

            cast_master_weights_to_fp8(
                self.weights,
                self.master_weights,
                [idx[0] for idx in self.weight_indices],
                self.dp_group,
                local_weights,
            )
        else:
            for weight, master_weight in zip(self.local_weights, self.master_weights):
                if master_weight is None:
                    continue

                # Copy updated master weights to local weights
                weight.data.copy_(master_weight)

        # Step 4: All-gather updated weights across processes
        dist.all_gather_into_tensor(
            self.flatten_weight, self.local_weight_shard, group=self.dp_group
        )
| 415 | + |
| 416 | + |
def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
    """Validate FSDP-style FP8 master-weight casting against a BF16 baseline.

    Trains two identical models — one with FP8 weights, one with BF16
    weights — for ``NUM_STEPS`` steps using :class:`MiniFSDP`, and asserts
    the losses match exactly at every step.

    Args:
        quantization: Quantization scheme name (``None`` disables FP8).
        dp_group: Data-parallel process group to shard across.
    """
    rank = dist.get_rank(dp_group)
    world_size = dist.get_world_size(dp_group)

    # Configuration constants
    NUM_STEPS = 100
    SEED = 12345

    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)

    # Single-rank groups so FP8 amax reduction stays local to each rank.
    mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
    mock_group = mock_groups[rank]

    linear_kwargs = {
        "params_dtype": torch.bfloat16,
        "bias": False,
        "fuse_wgrad_accumulation": False,
    }

    # Create model with FP8 weights
    with te.fp8.fp8_model_init(
        enabled=quantization is not None,
        recipe=quantization_recipe(quantization),
        preserve_high_precision_init_val=True,
    ):
        model_fp8 = nn.Sequential(
            te.Linear(128, 256, **linear_kwargs),
            te.Linear(256, 256 * 3, **linear_kwargs),
            te.Linear(256 * 3, 128, **linear_kwargs),
        )

    # Create model with BF16 weights
    model = nn.Sequential(
        te.Linear(128, 256, **linear_kwargs),
        te.Linear(256, 256 * 3, **linear_kwargs),
        te.Linear(256 * 3, 128, **linear_kwargs),
    )

    # Make sure the BF16 model and FP8 model have the same initial weights
    for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
        high_precision_init_val = w_fp8.get_high_precision_init_val()
        w.data.copy_(high_precision_init_val)

    optimizer_fp8 = MiniFSDP(list(model_fp8.parameters()), 10.0, dp_group)
    optimizer = MiniFSDP(list(model.parameters()), 10.0, dp_group)

    # Run NUM_STEPS steps (the original hard-coded range(100), letting the
    # loop silently drift from the constant reported in the success message).
    for _ in range(NUM_STEPS):
        optimizer_fp8.zero_grad()
        optimizer.zero_grad()

        inputs = [
            torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
        ]
        # Choose based on rank to make sure the inputs of different ranks are different.
        x = inputs[rank]

        # Use the same autocast spelling for both models (the original mixed
        # te.fp8.fp8_autocast and te.fp8_autocast).
        with te.fp8.fp8_autocast(
            enabled=quantization is not None,
            fp8_recipe=quantization_recipe(quantization),
            fp8_group=mock_group,
        ):
            y_fp8 = model_fp8(x)

        with te.fp8.fp8_autocast(
            enabled=quantization is not None,
            fp8_recipe=quantization_recipe(quantization),
            fp8_group=mock_group,
        ):
            y = model(x)

        targets = [torch.randn_like(y) for _ in range(world_size)]
        # Choose based on rank to make sure the targets of different ranks are different.
        target = targets[rank]
        loss_fp8 = nn.MSELoss()(y_fp8, target)
        loss = nn.MSELoss()(y, target)

        loss_fp8.backward()
        loss.backward()

        optimizer_fp8.step()
        optimizer.step()

        # Exact (bitwise) match required between FP8 and BF16 paths.
        torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)

    print(
        f"✅ Successfully validated FSDP {NUM_STEPS} training steps with"
        f" {quantization} quantization"
    )
| 507 | + |
231 | 508 | def _test_zero_1(dp_group): |
232 | 509 | """Make sure the implementation of zero-1 optimizer is correct""" |
233 | 510 | rank = dist.get_rank(dp_group) |
@@ -389,6 +666,7 @@ def main(argv=None, namespace=None): |
389 | 666 | dp_group = dist.new_group(backend="nccl") |
390 | 667 | _test_zero_1(dp_group) |
391 | 668 | _test_cast_master_weights_to_fp8(args.quantization, dp_group) |
| 669 | + _test_fsdp_cast_master_weights_to_fp8(args.quantization, dp_group) |
392 | 670 |
|
393 | 671 | dist.destroy_process_group() |
394 | 672 | return 0 |
|
0 commit comments