
NVFP4 primary weight support #2691

Open
WanZzzzzz wants to merge 12 commits into NVIDIA:main from WanZzzzzz:fp4_native_weights

Conversation

@WanZzzzzz

Description

This PR adds NVFP4 partial cast support for distributed training with ZeRO/FSDP optimizers. It enables efficient casting of FP32 master weight shards to NVFP4 model weights with coordinated scaling across data parallel ranks, while minimizing CPU overhead in large-scale training.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

This PR introduces NVFP4 partial cast infrastructure and optimizations for distributed training:

NVFP4 Partial Cast Kernel (nvfp4_2d_partial_cast)

  • Implements nibble-accurate partial updates for NVFP4 tensors in distributed settings
  • Supports two-level NVFP4 scaling: global FP32 scale + per-block FP8 E4M3 scale
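
A rough sketch of how the two scale levels combine when quantizing one 16-element block is shown below; the FP4 max of 6.0, the decode convention, and the formulas are illustrative assumptions, not the kernel's exact arithmetic.

```python
import numpy as np

FP4_MAX = 6.0  # assumed max magnitude of FP4 E2M1

def quantize_block(block_fp32, global_scale, block_scale):
    """Illustrative two-level NVFP4 quantization of one 16-element block.

    global_scale is the tensor-wide FP32 scale; block_scale is the per-block
    scale that would be stored as FP8 E4M3. Dequantization is assumed to be
    q * block_scale / global_scale, so encoding uses the inverse factor.
    """
    encode = np.float32(global_scale) / np.float32(block_scale)
    q = np.clip(block_fp32 * encode, -FP4_MAX, FP4_MAX)  # now within FP4 range
    # The real kernel would round q to FP4 E2M1 codes and pack two per byte.
    return q.astype(np.float32)

block = np.random.randn(16).astype(np.float32)
q = quantize_block(block, global_scale=np.float32(768.0), block_scale=np.float32(128.0))
```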

NVFP4 Transpose Kernel (nvfp4_transpose)

  • Custom transpose kernel for nibble-packed NVFP4 data with shared memory optimization
  • Uses vectorized uint2 loads/stores with 64×64 tiles for efficient memory access
  • Handles nibble repacking during transpose (unlike FP8 byte transpose)
  • Enables columnwise data generation for GEMM operations after rowwise AllGather
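
Below is a pure-PyTorch reference of what the kernel must do; the low-nibble-first packing order is an assumption, and the CUDA kernel performs the same repacking in shared memory with vectorized uint2 accesses and 64×64 tiles.

```python
import torch

def nvfp4_transpose_reference(packed: torch.Tensor, rows: int, cols: int) -> torch.Tensor:
    """Reference nibble-packed transpose (not the CUDA kernel).

    packed holds a rows x cols FP4 matrix as uint8, two 4-bit codes per byte
    along the row dimension, so its shape is (rows, cols // 2). rows and cols
    are assumed to be even.
    """
    lo = packed & 0x0F           # even-column codes (low nibble first, assumed)
    hi = (packed >> 4) & 0x0F    # odd-column codes
    # Unpack into a rows x cols matrix of 4-bit codes, transpose, then repack.
    codes = torch.stack((lo, hi), dim=-1).reshape(rows, cols)
    codes_t = codes.T.contiguous()            # cols x rows
    lo_t = codes_t[:, 0::2]
    hi_t = codes_t[:, 1::2]
    return lo_t | (hi_t << 4)                 # cols x (rows // 2), uint8

packed = torch.randint(0, 256, (64, 32), dtype=torch.uint8)  # a 64 x 64 FP4 tile
out = nvfp4_transpose_reference(packed, rows=64, cols=64)
assert out.shape == (64, 32)
```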

Fused Scale Kernel (nvfp4_fused_scale)

  • Fuses per-block scale computation, global amax copy, and FP8 scale expansion into a single kernel
  • Eliminates multiple kernel launches and avoids D2H transfers by accepting tensor pointers
  • Reduces kernel launch overhead in the critical path
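
For reference, the unfused sequence the kernel replaces looks roughly like the PyTorch sketch below; the scale formulas and the 6.0/448.0 constants are assumptions for illustration, and the float8 cast requires a PyTorch build with float8 dtypes.

```python
import torch

# Illustrative stand-ins for the kernel's inputs/outputs.
block_amax = torch.rand(1024) + 1e-3        # per-block amax values
global_amax = block_amax.max()              # tensor-level amax
global_scale = (6.0 * 448.0) / global_amax  # assumed global-scale formula

# Run separately, these are three launches; the fused kernel also takes the
# global amax by pointer instead of reading it back to the host.
block_scale = block_amax * (global_scale / 6.0)   # 1) per-block decode scale
amax_out = global_amax.clone()                    # 2) copy global amax to its target buffer
scale_fp8 = block_scale.to(torch.float8_e4m3fn)   # 3) expand/store scales as FP8 E4M3
```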

Multi-Tensor Dispatch Pattern

  • C++-side loop dispatch for NVFP4 multi-tensor operations
  • Reduces Python–C++ transition overhead compared to per-tensor Python loops
  • Collects metadata in Python and executes batched operations in C++ wrappers
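
The caller-side pattern looks roughly like the runnable stand-in below; in the actual extension the inner loop lives in C++ behind a single binding call, and the function here (with a plain copy standing in for the NVFP4 cast) is purely illustrative.

```python
import torch

def batched_partial_cast(srcs, dsts, offsets):
    """Stand-in for the C++ multi-tensor entry point: one call, internal loop."""
    for src, dst, off in zip(srcs, dsts, offsets):
        dst[off:off + src.numel()].copy_(src)  # placeholder for the NVFP4 partial cast

# Metadata is collected once in Python...
srcs = [torch.randn(8) for _ in range(4)]
dsts = [torch.zeros(16) for _ in range(4)]
offsets = [4, 0, 8, 2]
# ...then handed over in a single batched call, paying one Python->C++
# transition per bucket instead of one per parameter.
batched_partial_cast(srcs, dsts, offsets)
```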

CPU Overhead Optimizations

  • Batched dtype conversion via torch.cat / torch.split
  • Replaced torch.zeros() with torch.empty() for immediately written buffers
  • Consolidated metadata collection and allocation phases
  • Optimized bucket partitioning for expert parallel buffers
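
As an example of the batching idea, the cat/cast/split pattern below replaces a per-tensor conversion loop (a generic sketch, not the exact code in this PR).

```python
import torch

shards = [torch.randn(n, dtype=torch.float32) for n in (64, 128, 32)]

# Per-tensor conversion: one cast (and one kernel launch) per parameter.
slow = [s.to(torch.bfloat16) for s in shards]

# Batched conversion: concatenate once, cast once, split views back out.
sizes = [s.numel() for s in shards]
fast = list(torch.cat(shards).to(torch.bfloat16).split(sizes))

assert all(torch.equal(a, b) for a, b in zip(slow, fast))
```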

Scale Computation Improvements

  • Fixed floating-point precision mismatch between Python and CUDA
  • Uses FP32 constants consistent with CUDA arithmetic
  • Ensures bitwise-identical results between partial and full quantization paths
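
The idea, sketched below with an illustrative scale formula, is to keep every constant and intermediate in FP32 on the Python side so that the reference rounds exactly like single-precision CUDA arithmetic.

```python
import numpy as np

FP4_MAX = np.float32(6.0)     # FP32 constants, matching the kernel's float math
E4M3_MAX = np.float32(448.0)

def global_scale_fp32(amax: np.float32) -> np.float32:
    # Every intermediate rounds to float32. Mixing in plain Python floats would
    # promote intermediates to float64 and can change the last ULP for some
    # amax values, breaking bitwise comparisons against the full cast path.
    return np.float32(FP4_MAX * E4M3_MAX / amax)

print(global_scale_fp32(np.float32(3.5)))
```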

New Public API

cast_master_weights_to_nvfp4()

  • Casts FP32 master weights to NVFP4 model weights
  • Handles global and per-block amax reduction across data parallel groups
  • Designed for low CPU overhead in distributed training loops
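
A hedged usage sketch inside a ZeRO-style step is shown below; the first two argument names and the import path are assumptions by analogy with cast_master_weights_to_fp8, while start_offsets, group, and manual_post_all_gather_processing appear in the review discussion further down.

```python
import torch.distributed as dist
# Import path is an assumption; the function may be exposed elsewhere in TE.
from transformer_engine.pytorch.tensor.utils import cast_master_weights_to_nvfp4

cast_master_weights_to_nvfp4(
    model_weights,    # NVFP4 model weights updated in place (assumed name)
    master_weights,   # local FP32 master-weight shards (assumed name)
    start_offsets,    # element offset of each shard within its full weight
    group=dist.group.WORLD,  # data-parallel group used for amax reduction
    manual_post_all_gather_processing=True,
)
# With manual_post_all_gather_processing=True, the caller runs the rowwise
# all-gather and then triggers the NVFP4 transpose / post-processing itself.
```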

Testing

| Test | Description |
| --- | --- |
| test_nvfp4_transpose_kernel | Verifies correctness for nibble-packed transpose |
| test_nvfp4_partial_cast_matches_full | Multi-GPU: partial cast + all-gather equals full cast |
| test_single_gpu_partial_cast_vs_full | Single-GPU: offset=0 partial cast matches reference quantizer |
| _test_cast_master_weights_to_nvfp4 | 500-iteration training loop with bitwise-identical loss |

This feature also passed numeric validation in GPT-3 training on the corresponding Megatron-Core branch:

https://gitlab-master.nvidia.com/qiyuw/megatron-lm-all/-/tree/fp4_primary_opt?ref_type=heads

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: qiyuw <qiyuw@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Feb 19, 2026

Greptile Summary

This PR implements NVFP4 partial cast infrastructure for distributed training with ZeRO/FSDP optimizers. The implementation enables efficient conversion of FP32 master weight shards to NVFP4 model weights with coordinated scaling across data parallel ranks.

Key Changes

  • NVFP4 Partial Cast Kernel: Implements nibble-accurate partial updates for distributed settings with two-level scaling (global FP32 + per-block FP8 E4M3)
  • NVFP4 Transpose Kernel: Custom transpose for nibble-packed data with vectorized uint2 loads/stores and 64×64 tiling
  • Fused Scale Kernel: Combines per-block scale computation, global amax copy, and FP8 scale expansion into single kernel
  • Multi-Tensor Dispatch: C++-side loop dispatch reduces Python-C++ transition overhead
  • CPU Optimizations: Batched dtype conversion via torch.cat/torch.split, torch.empty() instead of torch.zeros() for buffers
  • New API: quantize_master_weights() function handles FP8 and NVFP4 quantization with distributed amax reduction

Issues from Previous Review

Most previous review comments have been addressed:

  • Variable naming issues fixed
  • Missing test assertions added
  • Duplicate imports removed
  • Byte index calculation bug for odd start_offset values remains documented but not fixed in this iteration

Testing

Comprehensive test coverage including transpose correctness, multi-GPU partial cast validation, single-GPU equivalence testing, and 500-iteration training loop verification

Confidence Score: 3/5

  • Safe to merge with caution - byte index calculation issue may affect irregular sharding scenarios
  • Score reflects known byte index calculation bug with odd offsets (documented in previous review), though current tests use even boundaries. Core functionality is well-tested for standard ZeRO/FSDP use cases.
  • Pay close attention to transformer_engine/common/recipe/nvfp4.cu: the byte index calculation on line 236 has a documented issue with odd start_offset values that could affect expert parallelism or irregular sharding patterns

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/recipe/nvfp4.cu | Implements NVFP4 partial cast, transpose, and scale kernels with multi-tensor optimizations; byte index calculation issue with odd offsets noted in previous review |
| transformer_engine/pytorch/tensor/utils.py | Adds quantize_master_weights with NVFP4 support, batched dtype conversion, and multi-tensor dispatch for reduced CPU overhead; previous issues with variable naming and docstrings addressed |
| transformer_engine/pytorch/csrc/extensions/nvfp4_2d_partial_cast.cpp | Adds C++ bindings for NVFP4 partial cast and amax computation with multi-tensor batch processing loop |
| tests/pytorch/distributed/test_cast_master_weights_to_fp8.py | Adds comprehensive NVFP4 tests including partial cast, transpose, and distributed training validation; missing assertion issues from previous review now fixed |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[FP32 Master Weight Shards] --> B[Batch Convert to BF16/FP16]
    B --> C[Compute Partial Amax per Block]
    C --> D[Compute Global Amax per Param]
    D --> E[AllReduce: Max Across DP Ranks]
    E --> F[Compute Global Scale]
    F --> G[Fused Scale Kernel]
    G --> H[Compute Per-Block Decode Scale]
    G --> I[Expand to Row-Level FP8 E4M3]
    G --> J[Copy Global Amax to Target]
    H --> K[Partial Cast: FP32 to NVFP4]
    I --> K
    F --> K
    K --> L[NVFP4 Rowwise Data Nibble-Packed]
    L --> M{AllGather Needed?}
    M -->|Yes| N[AllGather Model Weights]
    M -->|No| O[Post Processing]
    N --> O
    O --> P[NVFP4 Data Transpose]
    O --> Q[NVFP4 Scale Transpose]
    P --> R[NVFP4 Columnwise Data]
    Q --> S[NVFP4 Columnwise Scale]
    R --> T[Ready for GEMM]
    S --> T

Last reviewed commit: c29d523


start_offsets,
group,
fsdp_shard_model_weights=None,
manual_post_all_gather_processing=False,
Collaborator

We added this kwarg to the FP8 functions for backward compatibility, but there's no point keeping it for these brand-new NVFP4 APIs:

Suggested change
manual_post_all_gather_processing=False,

Author

fsdp_shard_model_weights=None is for future FSDP support. It's in the plan.
manual_post_all_gather_processing is also needed for the same reason as FP8 blockwise scaling:
https://github.com/WanZzzzzz/TransformerEngine/blob/38b92b1a168dcfaa6242fea50f03e5a1b873e3a0/transformer_engine/pytorch/tensor/utils.py#L535

Collaborator

@timmoon10 timmoon10 Feb 20, 2026

I see, that makes sense for now then. Let's change the default to True though since that's preferred.

I want to flag a potential future problem with manual_post_all_gather_processing=False: it assumes that the quantized tensor has some way to handle the post-processing automatically. For FP8 on Hopper:

cast_master_weights_to_fp8(..., manual_post_all_gather_processing=False)
torch.all_gather(...)

y = model(x)  # Float8Tensor internally performs FP8 transpose

This is not something TE will guarantee for future data formats. Maybe the next recipe has some interleaved format:

cast_master_weights_to_futureformat(...)
torch.all_gather(...)
fix_futureformat_interleaving(...)

y = model(x)  # FutureFormatTensor assumes data is interleaved

In this case, we should throw an error when the user passes manual_post_all_gather_processing=False, and it should be Mcore's responsibility to perform the post-processing in a way that's friendly to overlapping.

Author

OK, noted.

Comment on lines 245 to 259
if isinstance(self.weights[0], QuantizedTensor):
    weight_buffer_dtype = torch.uint8
    if self.weights_are_nvfp4:
        weight_buffer_length = self.storage_total
        buffer_rank_start = storage_rank_start
        buffer_rank_end = storage_rank_end
    else:
        weight_buffer_length = self.offsets[-1]
        buffer_rank_start = rank_start
        buffer_rank_end = rank_end
else:
    weight_buffer_dtype = weights[0].dtype
    weight_buffer_length = self.offsets[-1]
    buffer_rank_start = rank_start
    buffer_rank_end = rank_end
Collaborator

@timmoon10 timmoon10 Feb 20, 2026

Nit: It's a bit convoluted, isn't it? It would be much nicer to disentangle the quantization logic from the buffer allocation by computing storage offsets in all cases (even if it's trivial for non-NVFP4 cases) and then using that blindly here.

Author

Done.

qiyuw and others added 2 commits February 20, 2026 05:52
Contributor

@greptile-apps greptile-apps bot left a comment

11 files reviewed, 1 comment


Signed-off-by: qiyuw <qiyuw@nvidia.com>

Signed-off-by: qiyuw <qiyuw@nvidia.com>

Signed-off-by: qiyuw <qiyuw@nvidia.com>

Signed-off-by: qiyuw <qiyuw@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

11 files reviewed, no comments


@timmoon10
Collaborator

/te-ci L1

timmoon10 previously approved these changes Feb 21, 2026
Collaborator

@timmoon10 timmoon10 left a comment

Overall LGTM, although there are some test failures related to missing licenses and linter warnings. I also still have some nits, although they are not blocking.

@timmoon10 timmoon10 self-requested a review February 21, 2026 00:09
@timmoon10 timmoon10 dismissed their stale review February 21, 2026 00:09

Test failures

Signed-off-by: qiyuw <qiyuw@nvidia.com>
Signed-off-by: qiyuw <qiyuw@nvidia.com>
qiyuw and others added 2 commits February 26, 2026 22:13
Signed-off-by: qiyuw <qiyuw@nvidia.com>
continue;
}
const size_t ref_idx = mask[first] ? elem_idx[first] : elem_idx[second];
const size_t byte_idx = (ref_idx - start_offset) >> 1;
Contributor

Byte index calculation incorrect for odd start_offset values

When start_offset is odd, this formula produces wrong byte indices. For example:

  • Element at index 2 with start_offset=1: (2-1)>>1 = 0 (wrong, should be 1)
  • Element at index 3 with start_offset=1: (3-1)>>1 = 1 (correct)

Should be: byte_idx = (ref_idx >> 1) - (start_offset >> 1)

Current tests only use even start_offset values (multiples of shard_size), so this bug isn't caught. While typical ZeRO/FSDP sharding uses even boundaries, irregular sharding (e.g., expert parallelism) could trigger this.
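
A quick pure-Python check of the arithmetic above:

```python
def byte_idx_current(ref_idx, start_offset):
    return (ref_idx - start_offset) >> 1         # formula in the kernel today

def byte_idx_proposed(ref_idx, start_offset):
    return (ref_idx >> 1) - (start_offset >> 1)  # suggested fix

# With start_offset=1: elements 2 and 3 both live in shard byte 1.
for ref_idx in (2, 3, 4):
    print(ref_idx, byte_idx_current(ref_idx, 1), byte_idx_proposed(ref_idx, 1))
# For even start_offset values the two formulas agree, which is why the
# current tests (even shard boundaries) don't catch the difference.
```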


Collaborator

@timmoon10 timmoon10 left a comment

LGTM, pending CI

@timmoon10 timmoon10 requested a review from ksivaman March 2, 2026 02:08
@timmoon10
Collaborator

/te-ci pytorch L1
