Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion build_tools/hipify/custom_map.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"__nv_fp4x4_e2m1" : "__hip_fp4x4_e2m1",
"__nv_fp4x2_storage_t" : "__hip_fp4x2_storage_t",
"#include <cudaTypedefs.h>" : "",
"#include <cuda/barrier>" : ""
"#include <cuda/barrier>" : "",
"#include <cuda_pipeline.h>" : ""
}
}

72 changes: 72 additions & 0 deletions tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand Down Expand Up @@ -246,3 +248,73 @@ def test_nvfp4_quantization_noncontiguous_inputs(
use_cpp_allocator=use_cpp_allocator,
with_random_sign_mask=with_random_sign_mask,
)


def _ref_wht16_tiled(x: torch.Tensor, sign_mask: int) -> torch.Tensor:
"""Reference 16-point WHT tiled along last dim, normalised by 0.25."""
x = x.float()
_rows, cols = x.shape
d = torch.tensor(
[((-1) ** ((sign_mask >> i) & 1)) for i in range(16)],
dtype=torch.float32, device=x.device,
)
out = x.clone()
for c in range(0, cols, 16):
tile = out[:, c:c+16] * d # apply sign
h = 1
while h < 16:
for i in range(0, 16, h * 2):
a = tile[:, i:i+h].clone()
b = tile[:, i+h:i+2*h].clone()
tile[:, i:i+h] = a + b
tile[:, i+h:i+2*h] = a - b
h *= 2
out[:, c:c+16] = tile * 0.25
return out


@pytest.mark.parametrize("rows,cols", [(64, 64), (128, 128)])
def test_hadamard_transform_amax(rows, cols):
    """
    Exercise nvte_hadamard_transform_amax through NVFP4Quantizer (with_rht=True)
    without requiring a full NVFP4 recipe.

    Verifies two amax values produced by the quantizer:
      * amax_rowwise  — pre-RHT amax, i.e. max|x| of the raw input;
      * amax_colwise  — post-RHT amax of the transposed input,
        i.e. max|WHT(x.T)| computed against the Python reference.
    """
    torch.manual_seed(42)
    inp = torch.randn((rows, cols), dtype=torch.bfloat16, device="cuda").contiguous()

    nvfp4_quantizer = NVFP4Quantizer(
        fp4_dtype=tex.DType.kFloat4E2M1,
        rowwise=True,
        columnwise=True,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=True,
        with_post_rht_amax=True,
        with_random_sign_mask=True,
    )
    quantized = nvfp4_quantizer(inp)

    # Pre-RHT amax must match the plain absolute maximum of the input.
    torch.testing.assert_close(
        quantized._amax_rowwise.float().squeeze(),
        inp.float().abs().max(),
        rtol=1e-3, atol=1e-3,
        msg=f"pre-RHT amax mismatch rows={rows} cols={cols}",
    )

    # Post-RHT amax: reference-transform the transpose with the quantizer's
    # randomly-drawn sign mask and compare absolute maxima as scalars.
    mask_t = nvfp4_quantizer.rht_matrix_random_sign_mask_t
    transposed = inp.t().contiguous()  # (cols, rows)
    reference = _ref_wht16_tiled(transposed, sign_mask=mask_t)
    torch.testing.assert_close(
        quantized._amax_columnwise.float().squeeze().item(),
        float(reference.float().abs().max()),
        rtol=2e-2, atol=2e-2,
        msg=f"post-RHT amax mismatch rows={rows} cols={cols}",
    )
2 changes: 1 addition & 1 deletion transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
hadamard_transform/hadamard_transform.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu)

if(USE_CUDA)
Expand All @@ -247,7 +248,6 @@ if(USE_CUDA)
list(APPEND transformer_engine_cuda_arch_specific_sources
gemm/cutlass_grouped_gemm.cu
transpose/quantize_transpose_square_blockwise.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu)
else()
#ROCm specific source codes
Expand Down
Loading
Loading