Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ cmake_install.cmake
*.hip
*_hip.h
*_hip.cpp
!transformer_engine/pytorch/csrc/extensions/mxfp4_hip.cpp
*_hip.cuh
hip_driver*
hip_runtime*
Expand Down
5 changes: 4 additions & 1 deletion build_tools/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,11 @@ def setup_pytorch_extension(
) -> setuptools.Extension:
"""Setup CUDA extension for PyTorch support"""

# Source files
# Source files - include both .cpp and .cu files
# .cu files will be hipified to .hip for ROCm builds
sources = all_files_in_dir(Path(csrc_source_files), name_extension="cpp")
cu_sources = all_files_in_dir(Path(csrc_source_files), name_extension="cu")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GPU code should live in the main TE, not in extensions.

sources.extend(cu_sources)

# Header files
if rocm_build():
Expand Down
1 change: 1 addition & 0 deletions transformer_engine/common/recipe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,7 @@ class MXFP4BlockScaling(Recipe):
fp4_format: Format = Format.E2M1
fp8_dpa: bool = False
fp8_mha: bool = False
use_hadamard: bool = os.getenv("NVTE_MXFP4_USE_HADAMARD", "0") == "1"

@property
def fp8_format(self) -> Format:
Expand Down
18 changes: 18 additions & 0 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,24 @@ def general_gemm(
# Use bfloat16 as default bias_dtype
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]

# MXFP4 GEMM: route to AITER a4w4 ASM kernels
from ..tensor.storage.mxfp4_tensor_storage import MXFP4TensorStorage

if isinstance(A, MXFP4TensorStorage) or isinstance(B, MXFP4TensorStorage):
from ..module.fp4_handler_gemm import fp4_gemm_layout

result = fp4_gemm_layout(
A,
B,
layout=layout,
out_dtype=out_dtype if out_dtype is not None else torch.bfloat16,
bias=bias,
out=out,
grad=grad,
accumulate=accumulate,
)
return result, None, None, None

if isinstance(A, Float8BlockwiseQTensorStorage) or isinstance(B, Float8BlockwiseQTensorStorage):
# There is no use_split_accumulator == False
# implementation for Float8BlockwiseQTensorStorage GEMM
Expand Down
16 changes: 16 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,22 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
const std::vector<int> &split_sections,
std::vector<py::handle> quantizer_list);

/***************************************************************************************************
* MXFP4 Quantization
**************************************************************************************************/

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> cast_transpose_mxfp4_fused_shuffle(
at::Tensor input,
std::optional<at::Tensor> rowwise_fp4_out,
std::optional<at::Tensor> rowwise_scale_out,
std::optional<at::Tensor> colwise_fp4_out,
std::optional<at::Tensor> colwise_scale_out,
bool shuffle_rowwise_scale,
bool shuffle_colwise_scale,
bool shuffle_rowwise_fp4,
bool shuffle_colwise_fp4,
bool use_hadamard);

/***************************************************************************************************
* Bias gradient fusions
**************************************************************************************************/
Expand Down
Loading
Loading