
Full MXFP4 Training Recipe#506

Open
sarthak-amd wants to merge 3 commits into dev from feature/mxfp4-recipe-210

Conversation


@sarthak-amd (Collaborator) commented Mar 26, 2026

  1. Introduces cast_transpose_mxfp4_fused_shuffle. The existing Triton-based MXFP4 cast-transpose does not fuse the shuffle and Hadamard steps. Replaced the te_quantize_triton path in MXFP4Quantizer.update_quantized() with a fully fused HIP path; using the Triton path for training would require an unfused Hadamard call for stable training.

  2. Adds fp4_gemm_handler, which dispatches a4w4 GEMM calls to the AITER backend with the layouts expected by AITER.

  3. Adds MXFP4 weight caching in Linear.forward() and LayerNormLinear.forward() that persists quantized MXFP4TensorStorage weights across forward passes.

The training recipe can be enabled with:

export FP4_RECIPE=mxfp4
export NVTE_MXFP4_USE_HADAMARD=1

For LoRA, use HEALING_ITER=340.

- Add mxfp4_hip.cpp: C++ wrapper for cast_transpose_mxfp4_fused_shuffle
- Add cast_transpose_mxfp4_kernel_shuffled.cu: HIP kernel for fused MXFP4
  cast, transpose, shuffle with optional Hadamard transform
- Register MXFP4 types and cast_transpose_mxfp4_fused_shuffle in pybind
- Declare function signature in extensions.h, type externs in pybind.h
- Extend build_tools/pytorch.py to collect .cu files for hipify/hipcc
- Unignore mxfp4_hip.cpp in .gitignore

Made-with: Cursor
- Add use_hadamard param to MXFP4Quantizer (replaces global USE_HADAMARD env var)
- Add use_hadamard field to MXFP4BlockScaling recipe, controlled by NVTE_MXFP4_USE_HADAMARD
- Add fp4_handler_gemm.py: AITER a4w4 ASM kernel dispatch for FP4 GEMM
- Remove manual kernel selection (_select_kernel); AITER handles CSV-based tuning
- Route MXFP4TensorStorage inputs to fp4_handler_gemm in general_gemm()

Made-with: Cursor
- Thread use_hadamard from MXFP4BlockScaling recipe into all MXFP4Quantizer
  instantiations (forward, backward, lazy columnwise)
- Remove unused grad_output_quantizer_mxfp4 (stored but never read in backward)
- Remove NVTE_MXFP4_DEBUG print blocks from forward/backward paths
- Clean up quantizer tuple from 7 to 6 elements
- Remove unused os import from linear.py

Made-with: Cursor
@sarthak-amd sarthak-amd marked this pull request as ready for review March 26, 2026 01:49
Reviewed snippet from build_tools/pytorch.py:

    # Source files - include both .cpp and .cu files
    # .cu files will be hipified to .hip for ROCm builds
    sources = all_files_in_dir(Path(csrc_source_files), name_extension="cpp")
    cu_sources = all_files_in_dir(Path(csrc_source_files), name_extension="cu")

GPU code should live in the main TE, not in extensions.


@wangye805 left a comment


Below is the list of required unit tests (C++-level gtests or PyTorch pytests):
1) C++-level gtest for cast_transpose_mxfp4_shuffled. You can refer to our NVFP4 gtest (https://github.com/ROCm/TransformerEngine/blob/dev/tests/cpp/operator/test_cast_nvfp4_transpose.cu) as an example.
2) PyTorch pytest for the AITER a4w4 GEMM. You can refer to this fp8 GEMM test in test_numerics as an example:

def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_quantizer):

3) Extend test_numerics with test_layernorm_linear_* (for example,
def test_layernorm_linear_accuracy(
) and test_linear_* (for example,
def test_linear_accuracy(dtype, bs, model, return_bias, bias):
) with the mxfp4 format.



3 participants