Conversation
- …g and refined test case generation for various configurations. Cleaned up unused variables and improved code readability in the FSDPAGTensor class by removing unnecessary parameters.
- … FusedAdam. Added a debug print for DTensor in MultiTensorApply.
- … tolerances for tensor comparisons. Updated test logic to accommodate new tolerance parameters for improved accuracy in floating-point comparisons.
- …l differences in gradient calculations. Cleaned up unused debug print statements in MultiTensorApply and ensured a proper newline at the end of the FSDPAGTensor serialization method.
```python
if not isinstance(quantizer, MXFP8Quantizer) and not self._keep_fp8_weight_transpose_cache:
    quantizer = module.quantizers["scaling_fwd"][self._fp8_meta_index]
    if not isinstance(quantizer, MXFP8Quantizer):
        quantizer.set_usage(columnwise=False)
```
For FSDP2 with FP8, keep_fp8_weight_transpose_cache should be False. Caching the transposed weight would imply an all-gather of the transposed tensor as well, increasing memory and communication and negating the advantages of FSDP2’s sharded parameter layout.
```diff
     data = torch.zeros_like(param, dtype=torch.int16)
 else:
-    data = torch.empty(param.shape, dtype=dtype, device=param.device)
+    data = torch.empty_like(param, dtype=dtype)
```
When using FSDP2, parameters are DTensors, but `torch.zeros()` and `torch.empty()` create regular PyTorch Tensors. This was causing:

```
[rank1]: RuntimeError: aten.copy_.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!
[rank7]:   File "/workspace/TransformerEngine/transformer_engine/pytorch/optimizers/fused_adam.py", line 422, in initialize_state
[rank7]:     self.set_scaled_state(param, "master_param", param.clone().detach().float())
[rank7]:   File "/workspace/TransformerEngine/transformer_engine/pytorch/optimizers/fused_adam.py", line 363, in set_scaled_state
[rank7]:     state[state_name].copy_(unscaled_state)
```

Fix:
Keep the optimizer state consistent with the parameter type: when parameters are DTensors, the state should be DTensors as well. Using `torch.empty_like(param, ...)` (and the same idea for the other state buffers) creates the state as a DTensor with the same placement as `param`, so both sides of `copy_` are DTensors and the error is avoided.
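The principle behind the fix can be sketched in plain Python, without torch or a distributed setup. Here a hypothetical `Wrapped` class stands in for DTensor and a plain list stands in for a regular Tensor; the names `empty`/`empty_like` only mimic the shape of the torch constructors:

```python
class Wrapped:
    """Toy stand-in for a distributed tensor wrapper such as DTensor."""
    def __init__(self, data):
        self.data = list(data)

    def copy_(self, other):
        # Like aten.copy_: refuse to mix wrapped and plain operands.
        if not isinstance(other, Wrapped):
            raise RuntimeError("got mixed plain and Wrapped operands")
        self.data = list(other.data)
        return self

def empty(n):
    # Analogue of torch.empty(shape, ...): always returns a plain container,
    # regardless of what type the parameter is.
    return [0] * n

def empty_like(ref):
    # Analogue of torch.empty_like(param, ...): preserves the wrapper type.
    if isinstance(ref, Wrapped):
        return Wrapped([0] * len(ref.data))
    return [0] * len(ref)

param = Wrapped([1.0, 2.0, 3.0])

# Buggy pattern: state built with empty() is plain, so copy_ raises.
try:
    param.copy_(empty(3))
except RuntimeError as e:
    print(e)  # got mixed plain and Wrapped operands

# Fixed pattern: empty_like(param) keeps the wrapper type, so copy_ succeeds.
state = empty_like(param)
state.copy_(param)
print(state.data)
```

This is why `empty_like`-style constructors are the safe default whenever the parameter may be a tensor subclass.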
Is this a cherry-pick of the upstream fix?
Upstream fixes this in TE v2.12, along with a few other fixes:
NVIDIA/TransformerEngine@fe8fad5#diff-0801a8d92a56d458946da1439b62e0add1613b7da83d31bc218a852b6b9e42b1
This wasn't cherry-picked.
…by adding a newline character after the pass statement in the test_dummy function.
```python
# Zero the parameter gradients
optimizer.zero_grad()
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
```
Does `with te.fp8_autocast(enabled=args.fp8_autocast, ...)` do the same?
It does, but since `te.fp8_autocast` was replaced by `te.autocast` in TE v2.10, I've made the change to stay consistent.
So will `with te.autocast(enabled=args.fp8_autocast, recipe=...)` do the same as the if/else?
Yes, it should. I'll make the changes.
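The refactor being discussed — collapsing an if/else around the forward pass into a single context manager that takes an `enabled` flag — can be sketched generically in plain Python. The `autocast` and `forward` names here are hypothetical stand-ins, not the TransformerEngine API:

```python
from contextlib import contextmanager

# Hypothetical global mode flag, mimicking an autocast-style context manager
# that accepts enabled=... (as te.autocast does in the discussion above).
ACTIVE = {"autocast": False}

@contextmanager
def autocast(enabled=True):
    prev = ACTIVE["autocast"]
    ACTIVE["autocast"] = enabled
    try:
        yield
    finally:
        ACTIVE["autocast"] = prev

def forward():
    return "fp8" if ACTIVE["autocast"] else "fp32"

# Before: branch on the flag at every call site.
def run_with_branch(use_fp8):
    if use_fp8:
        with autocast(enabled=True):
            return forward()
    else:
        return forward()

# After: one call site; the flag is passed straight through to the context.
def run_with_flag(use_fp8):
    with autocast(enabled=use_fp8):
        return forward()

for flag in (True, False):
    assert run_with_branch(flag) == run_with_flag(flag)
print("equivalent")
```

Because `enabled=False` makes the context a no-op, the single `with` statement is behaviorally equivalent to the if/else while keeping one code path.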
```diff
 assert len(l1) == len(l2), "Unequal number of outputs."
 for i, (t1, t2) in enumerate(zip(l1, l2)):
-    result = torch.allclose(t1, t2, atol=0, rtol=0)
+    tols = dict(atol=atol)
```
Move the `tols` calculation out of the loop.
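The suggestion amounts to hoisting a loop-invariant dict construction. A minimal sketch of the resulting shape, using plain floats and `math.isclose` as a stand-in for `torch.allclose` (the helper name is hypothetical):

```python
import math

def compare_outputs(l1, l2, atol):
    """Compare two lists of scalar outputs element-wise."""
    assert len(l1) == len(l2), "Unequal number of outputs."
    tols = dict(abs_tol=atol)  # built once, outside the loop
    for i, (t1, t2) in enumerate(zip(l1, l2)):
        if not math.isclose(t1, t2, **tols):
            raise AssertionError(f"Output {i} mismatched: {t1} vs {t2}")

compare_outputs([1.0, 2.0], [1.0 + 1e-9, 2.0], atol=1e-6)
print("all outputs match")
```

The hoist is a readability/micro-efficiency nit rather than a correctness fix: the dict contents do not depend on the loop variable, so rebuilding it per iteration is pure waste.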
…s for improved clarity and consistency.
Manually ported the fix from upstream commit 139c863. The full commit was not cherry-picked due to unrelated changes across many files. Addressed PR comments.
```python
# scales and produce bit-identical FP8 GEMMs — strict tolerance (0) is used.
if quantized_init or (not quantized_init and not autocast):
    atol = 1e-6
    rtol = 5e-5
```
If our reference is DDP with the same FP8 primary weights, then the same cast from the FP32 master weight to FP8 happens in both the target and reference flows. Shouldn't we then get an exact match?
Description
This PR adds unit tests covering different configurations, such as:
All the unit tests compare FSDP2 vs DDP grads/outputs.
This PR also cleans up `fsdp2_all_gather_tensor` to match upstream's methods.
This PR also fixes an issue with `fused_adam` when used with FSDP2.
Fixes # (https://github.com/ROCm/frameworks-internal/issues/15291)
Type of change
Checklist: