Conversation
- Add mxfp4_hip.cpp: C++ wrapper for cast_transpose_mxfp4_fused_shuffle
- Add cast_transpose_mxfp4_kernel_shuffled.cu: HIP kernel for fused MXFP4 cast, transpose, shuffle with optional Hadamard transform
- Register MXFP4 types and cast_transpose_mxfp4_fused_shuffle in pybind
- Declare function signature in extensions.h, type externs in pybind.h
- Extend build_tools/pytorch.py to collect .cu files for hipify/hipcc
- Unignore mxfp4_hip.cpp in .gitignore

Made-with: Cursor
- Add use_hadamard param to MXFP4Quantizer (replaces global USE_HADAMARD env var)
- Add use_hadamard field to MXFP4BlockScaling recipe, controlled by NVTE_MXFP4_USE_HADAMARD
- Add fp4_handler_gemm.py: AITER a4w4 ASM kernel dispatch for FP4 GEMM
- Remove manual kernel selection (_select_kernel); AITER handles CSV-based tuning
- Route MXFP4TensorStorage inputs to fp4_handler_gemm in general_gemm()

Made-with: Cursor
- Thread use_hadamard from MXFP4BlockScaling recipe into all MXFP4Quantizer instantiations (forward, backward, lazy columnwise)
- Remove unused grad_output_quantizer_mxfp4 (stored but never read in backward)
- Remove NVTE_MXFP4_DEBUG print blocks from forward/backward paths
- Clean up quantizer tuple from 7 to 6 elements
- Remove unused os import from linear.py

Made-with: Cursor
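The commit above threads the recipe flag into every quantizer. A hedged sketch of that wiring (not the actual TE API; `make_quantizers` and the constructor signatures are assumptions mirroring the commit notes):

```python
from dataclasses import dataclass

@dataclass
class MXFP4BlockScaling:
    # Recipe-level flag, set via NVTE_MXFP4_USE_HADAMARD; replaces the
    # old global USE_HADAMARD env var read inside the quantizer.
    use_hadamard: bool = False

class MXFP4Quantizer:
    def __init__(self, use_hadamard: bool = False):
        self.use_hadamard = use_hadamard

def make_quantizers(recipe: MXFP4BlockScaling) -> tuple:
    # Forward (input/weight) and backward (grad, lazy columnwise)
    # quantizers all inherit the flag from the recipe, so the 6-element
    # quantizer tuple stays consistent with the recipe settings.
    return tuple(
        MXFP4Quantizer(use_hadamard=recipe.use_hadamard) for _ in range(6)
    )
```

The point of routing the flag through the recipe is that a single object controls all quantizer instantiations, instead of each call site consulting the environment.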
# Source files - include both .cpp and .cu files
# .cu files will be hipified to .hip for ROCm builds
sources = all_files_in_dir(Path(csrc_source_files), name_extension="cpp")
cu_sources = all_files_in_dir(Path(csrc_source_files), name_extension="cu")
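A self-contained sketch of what this build-tools change does, with an assumed implementation of the `all_files_in_dir` helper (the real one in build_tools/pytorch.py may differ):

```python
from pathlib import Path
import tempfile

def all_files_in_dir(path: Path, name_extension: str) -> list:
    # Assumed shape of the helper: recursive glob by file extension.
    return sorted(path.rglob(f"*.{name_extension}"))

def collect_sources(csrc_source_files) -> list:
    # .cpp files go straight to the host compiler; .cu files are later
    # hipified to .hip and compiled with hipcc on ROCm builds.
    root = Path(csrc_source_files)
    sources = all_files_in_dir(root, name_extension="cpp")
    cu_sources = all_files_in_dir(root, name_extension="cu")
    return sources + cu_sources

# Tiny demo on a throwaway tree with one file of each kind.
demo = Path(tempfile.mkdtemp())
(demo / "mxfp4_hip.cpp").touch()
(demo / "cast_transpose_mxfp4_kernel_shuffled.cu").touch()
collected = collect_sources(demo)
```

Collecting the two extensions separately lets the build hand only the .cu list to the hipify step while .cpp files compile unchanged.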
GPU code should live in the main TE, not in extensions.
wangye805 left a comment:
Below is the list of required unit tests (C++-level gtests or PyTorch pytests):
1) C++-level gtest for cast_transpose_mxfp4_shuffled. You can refer to our NVFP4 gtest (https://github.com/ROCm/TransformerEngine/blob/dev/tests/cpp/operator/test_cast_nvfp4_transpose.cu) as an example.
2) PyTorch pytest for the AITER a4w4 GEMM. You can refer to the fp8 GEMM test in test_numerics as an example: TransformerEngine/tests/pytorch/test_numerics.py, line 3025 in 82617fe.
3) Extend test_numerics with test_layernorm_linear_* cases (for example, TransformerEngine/tests/pytorch/test_numerics.py, line 1618 and line 1302 in 82617fe).
Introduces `cast_transpose_mxfp4_fused_shuffle`. The existing Triton-based MXFP4 cast-transpose does not fuse the shuffle and Hadamard steps. Replaces the `te_quantize_triton` path in `MXFP4Quantizer.update_quantized()` with the fully fused HIP path; using the Triton path for training would require an unfused Hadamard call for stable training. Adds `fp4_gemm_handler`, which dispatches a4w4 GEMM calls to the AITER backend with the layouts expected by AITER. Adds MXFP4 weight caching in `Linear.forward()` and `LayerNormLinear.forward()` that persists quantized `MXFP4TensorStorage` weights across forward passes.
Training recipe can be enabled using
For LoRA, use HEALING_ITER=340.