Signed-off-by: qiyuw <qiyuw@nvidia.com>
**Greptile Summary**

This PR implements NVFP4 partial cast infrastructure for distributed training with ZeRO/FSDP optimizers. The implementation enables efficient conversion of FP32 master weight shards to NVFP4 model weights with coordinated scaling across data parallel ranks.

**Key Changes**

**Issues from Previous Review:** Previous review comments have been addressed.

**Testing:** Comprehensive test coverage including transpose correctness, multi-GPU partial cast validation, single-GPU equivalence testing, and 500-iteration training loop verification.

**Confidence Score:** 3/5

**Important Files Changed**
**Flowchart**

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[FP32 Master Weight Shards] --> B[Batch Convert to BF16/FP16]
    B --> C[Compute Partial Amax per Block]
    C --> D[Compute Global Amax per Param]
    D --> E[AllReduce: Max Across DP Ranks]
    E --> F[Compute Global Scale]
    F --> G[Fused Scale Kernel]
    G --> H[Compute Per-Block Decode Scale]
    G --> I[Expand to Row-Level FP8 E4M3]
    G --> J[Copy Global Amax to Target]
    H --> K[Partial Cast: FP32 to NVFP4]
    I --> K
    F --> K
    K --> L[NVFP4 Rowwise Data Nibble-Packed]
    L --> M{AllGather Needed?}
    M -->|Yes| N[AllGather Model Weights]
    M -->|No| O[Post Processing]
    N --> O
    O --> P[NVFP4 Data Transpose]
    O --> Q[NVFP4 Scale Transpose]
    P --> R[NVFP4 Columnwise Data]
    Q --> S[NVFP4 Columnwise Scale]
    R --> T[Ready for GEMM]
    S --> T
```
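The coordinated-scaling steps in the flowchart (partial amax per rank, max all-reduce across DP ranks, then a single global scale) can be sketched in plain Python. This is an illustrative sketch only: DP ranks are simulated with lists instead of `torch.distributed` collectives, and the scale constants (FP4 max 6.0, FP8 E4M3 max 448.0) are assumptions for illustration.

```python
# Hypothetical sketch of coordinated scaling across DP ranks; ranks are
# simulated with plain lists instead of torch.distributed collectives.

def local_amax(shard):
    """Per-rank partial amax over this rank's master-weight shard."""
    return max((abs(x) for x in shard), default=0.0)

def allreduce_max(values):
    """Stand-in for an all-reduce with the MAX op across DP ranks."""
    return max(values)

def global_encode_scale(amax, fp4_max=6.0, fp8_max=448.0):
    # One possible NVFP4 convention: scale so amax maps to the largest
    # value representable by (per-block FP8 scale) x (FP4 element).
    # These constants are assumptions for illustration.
    return (fp4_max * fp8_max) / amax if amax > 0 else 1.0

# Each simulated rank holds a disjoint shard of the same parameter.
shards = [[0.5, -2.0], [1.5, 0.25], [-3.0, 1.0]]
partial = [local_amax(s) for s in shards]  # one partial amax per rank
amax = allreduce_max(partial)              # identical on every rank
scale = global_encode_scale(amax)          # every rank casts with this
```

Because every rank reduces to the same global amax, all shards of a parameter are quantized with one consistent scale, which is what makes the later all-gather of NVFP4 shards valid.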
Last reviewed commit: c29d523
```python
    start_offsets,
    group,
    fsdp_shard_model_weights=None,
    manual_post_all_gather_processing=False,
```
We added these kwargs to the FP8 functions for backward compatibility, but there's no point keeping them for these brand-new NVFP4 APIs:

```python
    manual_post_all_gather_processing=False,
```
`fsdp_shard_model_weights=None` is for future FSDP support. It's in the plan.

`manual_post_all_gather_processing` is also needed for the same reason as FP8 blockwise scaling:

https://github.com/WanZzzzzz/TransformerEngine/blob/38b92b1a168dcfaa6242fea50f03e5a1b873e3a0/transformer_engine/pytorch/tensor/utils.py#L535
I see, that makes sense for now then. Let's change the default to True though since that's preferred.
I want to flag a potential future problem with `manual_post_all_gather_processing=False`: it assumes that the quantized tensor has some way to handle the post-processing automatically. For FP8 on Hopper:

```python
cast_master_weights_to_fp8(..., manual_post_all_gather_processing=False)
torch.all_gather(...)
y = model(x)  # Float8Tensor internally performs FP8 transpose
```

This is not something TE will guarantee for future data formats. Maybe the next recipe has some interleaved format:

```python
cast_master_weights_to_futureformat(...)
torch.all_gather(...)
fix_futureformat_interleaving(...)
y = model(x)  # FutureFormatTensor assumes data is interleaved
```

In this case, we should throw an error when the user passes `manual_post_all_gather_processing=False`, and it should be Mcore's responsibility to perform the post-processing in a way that's friendly to overlapping.
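The guard suggested here could look like the sketch below. Note `cast_master_weights_to_futureformat` and `fix_futureformat_interleaving` are the hypothetical names from the example in this thread, not real TE APIs:

```python
# Hypothetical guard for a format whose post-all-gather processing cannot
# be performed automatically by the tensor class (all names illustrative).
def cast_master_weights_to_futureformat(
    *args, manual_post_all_gather_processing=True, **kwargs
):
    if not manual_post_all_gather_processing:
        raise NotImplementedError(
            "FutureFormatTensor cannot de-interleave itself after all-gather; "
            "pass manual_post_all_gather_processing=True and call "
            "fix_futureformat_interleaving() explicitly."
        )
    # ... the actual partial cast would go here ...
    return "casted"
```

Failing loudly at the cast call, rather than producing silently mis-laid-out weights at GEMM time, keeps the contract explicit for the framework (Mcore) that owns the overlap schedule.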
```python
if isinstance(self.weights[0], QuantizedTensor):
    weight_buffer_dtype = torch.uint8
    if self.weights_are_nvfp4:
        weight_buffer_length = self.storage_total
        buffer_rank_start = storage_rank_start
        buffer_rank_end = storage_rank_end
    else:
        weight_buffer_length = self.offsets[-1]
        buffer_rank_start = rank_start
        buffer_rank_end = rank_end
else:
    weight_buffer_dtype = weights[0].dtype
    weight_buffer_length = self.offsets[-1]
    buffer_rank_start = rank_start
    buffer_rank_end = rank_end
```
Nit: It's a bit convoluted, isn't it? It would be much nicer to disentangle the quantization logic from the buffer allocation by computing storage offsets in all cases (even if it's trivial for non-NVFP4 cases) and then using that blindly here.
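One way to act on this nit is to compute the layout uniformly up front so the allocation site never branches on the quantization type. The helper below is a hypothetical sketch; names like `storage_total` mirror the diff, but the structure is illustrative, not the PR's implementation:

```python
# Hypothetical refactor: compute (dtype, length, start, end) for every
# case up front, so buffer allocation itself is branch-free.
from dataclasses import dataclass

@dataclass
class BufferLayout:
    dtype: str        # stand-in for a torch dtype
    length: int
    rank_start: int
    rank_end: int

def compute_buffer_layout(
    is_quantized, is_nvfp4, elem_dtype,
    offsets, storage_total, rank_span, storage_span,
):
    # NVFP4 packs two elements per byte, so it carries its own storage
    # offsets; every other case degenerates to trivial element offsets.
    if is_quantized and is_nvfp4:
        start, end = storage_span
        return BufferLayout("uint8", storage_total, start, end)
    start, end = rank_span
    dtype = "uint8" if is_quantized else elem_dtype
    return BufferLayout(dtype, offsets[-1], start, end)

# The caller then allocates from the layout without any special-casing.
layout = compute_buffer_layout(
    is_quantized=True, is_nvfp4=True, elem_dtype="bf16",
    offsets=[0, 128], storage_total=64,
    rank_span=(0, 128), storage_span=(0, 64),
)
```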
for more information, see https://pre-commit.ci
/te-ci L1
timmoon10 left a comment:
Overall LGTM, although there are some test failures related to missing licenses and linter warnings. I also still have some nits, although they are not blocking.
```cpp
  continue;
}
const size_t ref_idx = mask[first] ? elem_idx[first] : elem_idx[second];
const size_t byte_idx = (ref_idx - start_offset) >> 1;
```
**Byte index calculation incorrect for odd `start_offset` values**

When `start_offset` is odd, this formula produces wrong byte indices. For example:

- Element at index 2 with `start_offset=1`: `(2-1)>>1 = 0` (wrong, should be 1)
- Element at index 3 with `start_offset=1`: `(3-1)>>1 = 1` (correct)

Should be: `byte_idx = (ref_idx >> 1) - (start_offset >> 1)`

Current tests only use even `start_offset` values (multiples of `shard_size`), so this bug isn't caught. While typical ZeRO/FSDP sharding uses even boundaries, irregular sharding (e.g., expert parallelism) could trigger this.
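The off-by-one is easy to reproduce in isolation; this sketch checks only the index arithmetic, not the CUDA kernel itself:

```python
def byte_idx_buggy(ref_idx, start_offset):
    # Formula from the diff: subtracts element offsets before pairing
    # nibbles, which breaks when start_offset is odd.
    return (ref_idx - start_offset) >> 1

def byte_idx_fixed(ref_idx, start_offset):
    # Convert each index to a byte offset first, then subtract.
    return (ref_idx >> 1) - (start_offset >> 1)

# Even start_offset: the two formulas agree.
assert byte_idx_buggy(6, 4) == byte_idx_fixed(6, 4) == 1
# Odd start_offset: element 2 pairs into byte 1 of the packed buffer,
# but the buggy formula says byte 0.
assert byte_idx_buggy(2, 1) == 0
assert byte_idx_fixed(2, 1) == 1
```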
/te-ci pytorch L1
Description
This PR adds NVFP4 partial cast support for distributed training with ZeRO/FSDP optimizers. It enables efficient casting of FP32 master weight shards to NVFP4 model weights with coordinated scaling across data parallel ranks, while minimizing CPU overhead in large-scale training.
Type of change
Changes
This PR introduces NVFP4 partial cast infrastructure and optimizations for distributed training:
NVFP4 Partial Cast Kernel (
nvfp4_2d_partial_cast)NVFP4 Transpose Kernel (
nvfp4_transpose)uint2loads/stores with 64×64 tiles for efficient memory accessFused Scale Kernel (
nvfp4_fused_scale)Multi-Tensor Dispatch Pattern
CPU Overhead Optimizations
torch.cat/torch.splittorch.zeros()withtorch.empty()for immediately written buffersScale Computation Improvements
New Public API
cast_master_weights_to_nvfp4()Testing
test_nvfp4_transpose_kerneltest_nvfp4_partial_cast_matches_fulltest_single_gpu_partial_cast_vs_full_test_cast_master_weights_to_nvfp4This feature also passed numeric validation in GPT-3 training on the corresponding Megatron-Core branch:
https://gitlab-master.nvidia.com/qiyuw/megatron-lm-all/-/tree/fp4_primary_opt?ref_type=heads
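For context on the "nibble-packed" rowwise output produced by the partial cast kernel, here is a minimal illustration of packing two 4-bit codes per byte. The low-nibble-first ordering is an assumption for illustration; the actual kernel layout may differ:

```python
def pack_nibbles(codes):
    # Pack pairs of 4-bit codes into bytes: even index -> low nibble,
    # odd index -> high nibble (assumed ordering, for illustration).
    assert len(codes) % 2 == 0
    return bytes(
        (codes[i] & 0xF) | ((codes[i + 1] & 0xF) << 4)
        for i in range(0, len(codes), 2)
    )

def unpack_nibbles(packed):
    # Inverse of pack_nibbles: recover the 4-bit codes in order.
    out = []
    for b in packed:
        out.append(b & 0xF)
        out.append(b >> 4)
    return out

packed = pack_nibbles([1, 2, 3, 4])  # two bytes for four FP4 codes
```

This halved storage is why the NVFP4 path above tracks separate byte-level storage offsets instead of element offsets when sizing the all-gather buffer.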
Checklist: