From 80e0aab6cab6c1cefbc1b0e5abfc5fd0aa6dcbbd Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 27 Mar 2026 13:58:07 -0500 Subject: [PATCH 1/4] initial impl --- build_tools/hipify/custom_map.json | 3 +- .../nvfp4/test_nvfp4_rht_quantize_exact.py | 66 ++++ transformer_engine/common/CMakeLists.txt | 2 +- .../hadamard_transform/hadamard_transform.cu | 318 ++++++++++++++++++ .../transformer_engine/hadamard_transform.h | 4 - transformer_engine/pytorch/csrc/common.h | 2 - transformer_engine/pytorch/csrc/pybind.h | 4 - transformer_engine/pytorch/csrc/quantizer.cpp | 8 +- 8 files changed, 393 insertions(+), 14 deletions(-) diff --git a/build_tools/hipify/custom_map.json b/build_tools/hipify/custom_map.json index 872d38efa..92b3f0a44 100644 --- a/build_tools/hipify/custom_map.json +++ b/build_tools/hipify/custom_map.json @@ -15,7 +15,8 @@ "__nv_fp4x4_e2m1" : "__hip_fp4x4_e2m1", "__nv_fp4x2_storage_t" : "__hip_fp4x2_storage_t", "#include " : "", - "#include " : "" + "#include " : "", + "#include " : "" } } diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index 98be9a4f5..b3b5ad8df 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -246,3 +246,69 @@ def test_nvfp4_quantization_noncontiguous_inputs( use_cpp_allocator=use_cpp_allocator, with_random_sign_mask=with_random_sign_mask, ) + + +def _ref_wht16_tiled(x: torch.Tensor, sign_mask: int) -> torch.Tensor: + """Pure-Python reference WHT: tiled 16-point butterfly, normalised by 0.25.""" + import numpy as np + x_np = x.float().cpu().numpy().copy() + rows, cols = x_np.shape + d = np.array([((-1) ** ((sign_mask >> i) & 1)) for i in range(16)], dtype=np.float32) + for c in range(0, cols, 16): + tile = x_np[:, c:c+16] * d + h = 1 + while h < 16: + for i in range(0, 16, h * 2): + for j in range(i, i + h): + a, b = tile[:, j].copy(), tile[:, j + h].copy() + tile[:, j], tile[:, j + 
h] = a + b, a - b + h *= 2 + x_np[:, c:c+16] = tile * 0.25 + return torch.from_numpy(x_np) + + +@pytest.mark.parametrize("rows,cols", [(64, 64), (128, 128)]) +def test_hadamard_transform_amax(rows, cols): + """ + Tests nvte_hadamard_transform_amax via NVFP4Quantizer (with_rht=True). + Exercises the WHT kernel without requiring a full NVFP4 recipe. + Checks: + - amax_rowwise == max|x| (pre-RHT amax of raw input) + - amax_colwise == max|WHT(x.T)| (post-RHT amax of transposed input) + """ + torch.manual_seed(42) + x = torch.randn((rows, cols), dtype=torch.bfloat16, device="cuda").contiguous() + + quantizer = NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=True, + with_post_rht_amax=True, + with_random_sign_mask=True, + ) + out = quantizer(x) + + # amax_rowwise: pre-RHT, should equal max|x| + expected_rowwise_amax = x.float().abs().max() + torch.testing.assert_close( + out._amax_rowwise.float().squeeze(), + expected_rowwise_amax, + rtol=1e-3, atol=1e-3, + msg=f"pre-RHT amax mismatch rows={rows} cols={cols}", + ) + + # amax_colwise: post-RHT of x.T, should equal max|WHT(x.T)| + sign_mask_t = quantizer.rht_matrix_random_sign_mask_t + x_t = x.t().contiguous() # (cols, rows) + wht_x_t = _ref_wht16_tiled(x_t, sign_mask=sign_mask_t) + expected_colwise_amax = wht_x_t.float().abs().max() + + torch.testing.assert_close( + out._amax_columnwise.float().squeeze().item(), + float(expected_colwise_amax), + rtol=2e-2, atol=2e-2, + msg=f"post-RHT amax mismatch rows={rows} cols={cols}", + ) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 8d5537368..847fbcb8e 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -223,6 +223,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources activation/gelu.cu activation/relu.cu activation/swiglu.cu + 
hadamard_transform/hadamard_transform.cu transpose/quantize_transpose_vector_blockwise_fp4.cu) if(USE_CUDA) @@ -247,7 +248,6 @@ if(USE_CUDA) list(APPEND transformer_engine_cuda_arch_specific_sources gemm/cutlass_grouped_gemm.cu transpose/quantize_transpose_square_blockwise.cu - hadamard_transform/hadamard_transform.cu hadamard_transform/hadamard_transform_cast_fusion.cu) else() #ROCm specific source codes diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index b901f9023..e2f449673 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -23,6 +23,8 @@ namespace { constexpr int kThreadsPerWarp = 32; constexpr float k16x16HadamardScale = 0.25f; +#ifndef __HIP_PLATFORM_AMD__ + template __device__ __forceinline__ void ldmatrix_x4_m8n8_shared_b16(uint32_t& a0, uint32_t& a1, uint32_t& a2, uint32_t& a3, @@ -658,12 +660,261 @@ __global__ void HadamardTransformKernel(const T* __restrict__ input, T* __restri #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 } +#endif // __HIP_PLATFORM_AMD__ } // namespace +#ifdef __HIP_PLATFORM_AMD__ + +namespace { + +static constexpr int kHadamardDim = 16; +static constexpr int kWarpSize = 64; +static constexpr int kThreadsPerWHT = 4; +static constexpr int kElemsPerThread = 4; +static constexpr int kRowsPerWarp = kWarpSize / kThreadsPerWHT; // 16 +static constexpr int kWarpsPerBlock = 4; +static constexpr int kRowsPerBlock = kRowsPerWarp * kWarpsPerBlock; // 64 +static constexpr int kThreadsPerBlock = kWarpSize * kWarpsPerBlock; // 256 +static constexpr float kHadamardScale = 0.25f; + +// ds_swizzle: sub-wavefront exchange without LDS. +// Same instructions as cast_transpose_mxfp4_kernel_shuffled.cu. 
+__device__ __forceinline__ float ds_swizzle_xor1(float v) { + float r; + asm volatile("ds_swizzle_b32 %0, %1 offset:0x041F\n\t" + "s_waitcnt lgkmcnt(0)" : "=v"(r) : "v"(v)); + return r; +} + +__device__ __forceinline__ float ds_swizzle_xor2(float v) { + float r; + asm volatile("ds_swizzle_b32 %0, %1 offset:0x081F\n\t" + "s_waitcnt lgkmcnt(0)" : "=v"(r) : "v"(v)); + return r; +} + +// BF16 helpers +__device__ __forceinline__ float to_f32 (__hip_bfloat16 v) { return static_cast(v); } +__device__ __forceinline__ __hip_bfloat16 to_bf16(float v) { return static_cast<__hip_bfloat16>(v); } + +// Bit-cast __hip_bfloat16->uint16_t without address-of-temporary. +__device__ __forceinline__ uint16_t bf16_to_bits(__hip_bfloat16 v) { + uint16_t bits; __builtin_memcpy(&bits, &v, sizeof(uint16_t)); return bits; +} + +// Unpack/pack 4 BF16 values as uint64_t (vectorised global load/store). +// Same trick as cast_transpose_mxfp4_kernel_shuffled.cu::bf16x4_to_float4. +__device__ __forceinline__ void unpack_bf16x4(uint64_t p, + float& v0, float& v1, float& v2, float& v3) { + v0 = __uint_as_float(((uint32_t)( p & 0xFFFF)) << 16); + v1 = __uint_as_float(((uint32_t)((p >> 16) & 0xFFFF)) << 16); + v2 = __uint_as_float(((uint32_t)((p >> 32) & 0xFFFF)) << 16); + v3 = __uint_as_float(((uint32_t)((p >> 48) & 0xFFFF)) << 16); +} + +__device__ __forceinline__ uint64_t pack_bf16x4(float v0, float v1, float v2, float v3) { + return (uint64_t)bf16_to_bits(to_bf16(v0)) + | ((uint64_t)bf16_to_bits(to_bf16(v1)) << 16) + | ((uint64_t)bf16_to_bits(to_bf16(v2)) << 32) + | ((uint64_t)bf16_to_bits(to_bf16(v3)) << 48); +} + +// 16-point WHT: in-register, no shared memory. +// Adapted from cast_transpose_mxfp4_kernel_shuffled.cu::hadamard16_inplace, +// extended with NV random_sign_mask (uint16_t bitmask). +// thread_in_group [0,3]: drives ds_swizzle polarity (identical to MLPerf tid & 3). +// apply_pre=true -> D before WHT (forward); false -> D after WHT (inverse). 
+__device__ __forceinline__ void wht16( + float& v0, float& v1, float& v2, float& v3, + int thread_in_group, uint16_t sign_mask, bool apply_pre) { + auto sgn = [&](int k) -> float { + return ((sign_mask >> (thread_in_group * kElemsPerThread + k)) & 1u) ? -1.f : 1.f; + }; + if (apply_pre) { + v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3); + } + + // Stage 1: local H4 + float a0=v0+v1, a1=v0-v1, a2=v2+v3, a3=v2-v3; + v0=a0+a2; v2=a0-a2; v1=a1+a3; v3=a1-a3; + + // Stage 2: cross-thread XOR-1 + { float p0=ds_swizzle_xor1(v0), p1=ds_swizzle_xor1(v1), + p2=ds_swizzle_xor1(v2), p3=ds_swizzle_xor1(v3); + bool up=(thread_in_group&1); + v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1); + v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); } + + // Stage 3: cross-thread XOR-2 + { float p0=ds_swizzle_xor2(v0), p1=ds_swizzle_xor2(v1), + p2=ds_swizzle_xor2(v2), p3=ds_swizzle_xor2(v3); + bool up=(thread_in_group>>1)&1; + v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1); + v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); } + + v0*=kHadamardScale; v1*=kHadamardScale; v2*=kHadamardScale; v3*=kHadamardScale; + if (!apply_pre) { v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3); } +} + +// Grid: blockIdx.x = col tile [0, row_length/16) +// blockIdx.y = row batch [0, ceil(num_rows/64)) +// Block: 256 threads = 4 wavefronts of 64 lanes. 
+// lane/4 = row_in_warp (0..15), lane%4 = thread_in_grp (0..3) +template +__global__ __launch_bounds__(kThreadsPerBlock, 4) +void HadamardTransformKernel( + const __hip_bfloat16* __restrict__ input, + __hip_bfloat16* __restrict__ output, + __hip_bfloat16* __restrict__ output_t, + uint16_t random_sign_mask, uint16_t random_sign_mask_t, + uint64_t num_rows, uint64_t row_length, + float* __restrict__ amax, float* __restrict__ amax_t, + bool inverse_hadamard) { + const int tid = threadIdx.x; + const int warp_id = tid / kWarpSize; + const int lane_id = tid % kWarpSize; + const int row_in_warp = lane_id / kThreadsPerWHT; + const int thread_in_grp = lane_id % kThreadsPerWHT; + const uint64_t col_tile_base = (uint64_t)blockIdx.x * kHadamardDim; + const uint64_t row_batch = (uint64_t)blockIdx.y * kRowsPerBlock; + const uint64_t global_row = row_batch + (uint64_t)warp_id*kRowsPerWarp + row_in_warp; + const uint64_t col_base = col_tile_base + (uint64_t)thread_in_grp * kElemsPerThread; + + const bool apply_pre = !inverse_hadamard; + const bool in_bounds = (global_row < num_rows) && (col_base + kElemsPerThread - 1 < row_length); + + // Smem for transposed path: 64×(16+1) BF16; +1 avoids LDS bank conflict. 
+ __shared__ __hip_bfloat16 smem[kRowsPerBlock][kHadamardDim + 1]; + float v0=0.f, v1=0.f, v2=0.f, v3=0.f; + if (in_bounds) { + unpack_bf16x4(*reinterpret_cast( + &input[global_row * row_length + col_base]), v0, v1, v2, v3); + } + + // Identity path: WHT along row dimension + if constexpr (kComputeIdentity || kUpdateAmax) { + float r0=v0, r1=v1, r2=v2, r3=v3; + if (global_row < num_rows) { + wht16(r0, r1, r2, r3, thread_in_grp, random_sign_mask, apply_pre); + if constexpr (kUpdateAmax) { + float lam = fmaxf(fmaxf(fabsf(r0),fabsf(r1)),fmaxf(fabsf(r2),fabsf(r3))); + for (int off=kWarpSize/2; off>=1; off>>=1) lam=fmaxf(lam,__shfl_xor(lam,off)); + if (lane_id == 0) atomicMaxFloat(amax, lam); + } + if constexpr (kComputeIdentity) + if (output && in_bounds) + *reinterpret_cast(&output[global_row*row_length+col_base]) = + pack_bf16x4(r0,r1,r2,r3); + } + } + + // Transposed path: WHT along column dimension via smem transpose + if constexpr (kComputeTransposed || kUpdateAmaxT) { + const int local_row = warp_id * kRowsPerWarp + row_in_warp; + const int col_offset = thread_in_grp * kElemsPerThread; + smem[local_row][col_offset+0] = to_bf16(global_row < num_rows ? v0 : 0.f); + smem[local_row][col_offset+1] = to_bf16(global_row < num_rows ? v1 : 0.f); + smem[local_row][col_offset+2] = to_bf16(global_row < num_rows ? v2 : 0.f); + smem[local_row][col_offset+3] = to_bf16(global_row < num_rows ? 
v3 : 0.f); + __syncthreads(); + + // Re-read: row_in_warp -> column index, thread_in_grp -> 4 rows + const int t_col = row_in_warp; + const int smem_rbase = warp_id*kRowsPerWarp + thread_in_grp*kElemsPerThread; + + float c0=to_f32(smem[smem_rbase+0][t_col]), c1=to_f32(smem[smem_rbase+1][t_col]); + float c2=to_f32(smem[smem_rbase+2][t_col]), c3=to_f32(smem[smem_rbase+3][t_col]); + + wht16(c0, c1, c2, c3, thread_in_grp, random_sign_mask_t, apply_pre); + + if constexpr (kUpdateAmaxT) { + float lam = fmaxf(fmaxf(fabsf(c0),fabsf(c1)),fmaxf(fabsf(c2),fabsf(c3))); + for (int off=kWarpSize/2; off>=1; off>>=1) lam=fmaxf(lam,__shfl_xor(lam,off)); + if (lane_id == 0) atomicMaxFloat(amax_t, lam); + } + + if constexpr (kComputeTransposed) { + if (output_t) { + const uint64_t global_col = col_tile_base + t_col; + const uint64_t out_row_base = row_batch + (uint64_t)warp_id*kRowsPerWarp + + (uint64_t)thread_in_grp*kElemsPerThread; + if (global_col < row_length && out_row_base+kElemsPerThread-1 < num_rows) + *reinterpret_cast( + &output_t[global_col*num_rows+out_row_base]) = + pack_bf16x4(c0,c1,c2,c3); + } + } + } +} + +// Pre-RHT amax: max|input| before any transform. 
+__global__ void PreRhtAmaxKernel(const __hip_bfloat16* __restrict__ input, + float* __restrict__ amax_out, uint64_t num_elems) { + float lam = 0.f; + for (uint64_t i = (uint64_t)blockIdx.x*blockDim.x+threadIdx.x; + i < num_elems; i += (uint64_t)gridDim.x*blockDim.x) + lam = fmaxf(lam, fabsf(to_f32(input[i]))); + for (int off=kWarpSize/2; off>=1; off>>=1) lam=fmaxf(lam,__shfl_xor(lam,off)); + if (threadIdx.x % kWarpSize == 0) atomicMaxFloat(amax_out, lam); +} + +static inline dim3 transform_grid(uint64_t num_rows, uint64_t row_length) { + return dim3((uint32_t)(row_length / kHadamardDim), + (uint32_t)DIVUP(num_rows, (uint64_t)kRowsPerBlock)); +} + +} // namespace + +#endif // __HIP_PLATFORM_AMD__ + void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask, uint16_t random_sign_mask_t, cudaStream_t stream) { NVTE_API_CALL(hadamard_transform); +#ifdef __HIP_PLATFORM_AMD__ + NVTE_CHECK(input_.dtype() == DType::kBFloat16, "Input must be BF16."); + NVTE_CHECK(input_.dim() >= 2, "Input must be >=2D."); + + const SimpleTensor& input = input_.data; + SimpleTensor identity_out; + SimpleTensor& transposed_out = output_.data; + + const bool want_identity = (identity_out.dptr != nullptr); + const bool want_transposed = (transposed_out.dptr != nullptr); + + if (!want_identity && !want_transposed) + return; + + const size_t ndim = input.shape.size(); + const size_t row_length = input.shape[ndim - 1]; + size_t num_rows = 1; + + for (size_t i = 0; i < ndim - 1; ++i) + num_rows *= input.shape[i]; + NVTE_CHECK(row_length % kHadamardDim == 0, "row_length must be divisible by 16."); + NVTE_CHECK(num_rows % kHadamardDim == 0, "num_rows must be divisible by 16."); + + auto* in_ptr = reinterpret_cast(input.dptr); + auto* id_ptr = reinterpret_cast<__hip_bfloat16*>(identity_out.dptr); + auto* tr_ptr = reinterpret_cast<__hip_bfloat16*>(transposed_out.dptr); + dim3 grid = transform_grid(num_rows, row_length), block(kThreadsPerBlock); + +#define LAUNCH_T(IDENT, 
TRANS) \ + HadamardTransformKernel \ + <<>>(in_ptr,id_ptr,tr_ptr, \ + random_sign_mask,random_sign_mask_t, \ + (uint64_t)num_rows,(uint64_t)row_length,nullptr,nullptr,false) + + if (want_identity && want_transposed) + LAUNCH_T(true, true); + else if (want_identity) + LAUNCH_T(true, false); + else + LAUNCH_T(false, true); + NVTE_CHECK_CUDA(cudaGetLastError()); +#undef LAUNCH_T +#else // CUDA // Check tensors // NOTE (frsun): This is non-intuitive, we are writing the result of // transposed RHT to the output of rowwise. @@ -736,6 +987,7 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s num_rows, row_length, nullptr, nullptr, false););); NVTE_CHECK_CUDA(cudaGetLastError()); +#endif // __HIP_PLATFORM_AMD__ } // Kernel that will apply the 16x16 hadamard transform the input and input.T, and then @@ -743,6 +995,71 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask, uint16_t random_sign_mask_t, cudaStream_t stream) { NVTE_API_CALL(hadamard_transform_amax); +#ifdef __HIP_PLATFORM_AMD__ + NVTE_CHECK(input_.dtype() == DType::kBFloat16, "Input must be BF16."); + NVTE_CHECK(input_.dim() >= 2, "Input must be >=2D."); + + const SimpleTensor& input = input_.data; + SimpleTensor& pre_rht_tensor = output_.amax; + SimpleTensor identity_tensor; + SimpleTensor& transpose_tensor = output_.columnwise_amax; + + const bool want_pre_rht = (pre_rht_tensor.dptr != nullptr); + const bool want_identity = (identity_tensor.dptr != nullptr); + const bool want_trans = (transpose_tensor.dptr != nullptr); + + if (!want_pre_rht && !want_identity && !want_trans) + return; + + const size_t ndim = input.shape.size(); + const size_t row_length = input.shape[ndim - 1]; + size_t num_rows = 1; + + for (size_t i = 0; i < ndim - 1; ++i) + num_rows *= input.shape[i]; + + NVTE_CHECK(row_length % kHadamardDim == 0, "row_length must be divisible by 16."); 
+ NVTE_CHECK(num_rows % kHadamardDim == 0, "num_rows must be divisible by 16."); + + auto* in_ptr = reinterpret_cast(input.dptr); + auto* pre_amax_ptr = reinterpret_cast(pre_rht_tensor.dptr); + auto* id_amax_ptr = reinterpret_cast(identity_tensor.dptr); + auto* tr_amax_ptr = reinterpret_cast(transpose_tensor.dptr); + + if (pre_amax_ptr) + NVTE_CHECK_CUDA(cudaMemsetAsync(pre_amax_ptr, 0, sizeof(float), stream)); + if (id_amax_ptr) + NVTE_CHECK_CUDA(cudaMemsetAsync(id_amax_ptr, 0, sizeof(float), stream)); + if (tr_amax_ptr) + NVTE_CHECK_CUDA(cudaMemsetAsync(tr_amax_ptr, 0, sizeof(float), stream)); + + if (want_pre_rht) { + const uint64_t num_elems = (uint64_t)num_rows * row_length; + dim3 g(DIVUP(num_elems, (uint64_t)kThreadsPerBlock)); + PreRhtAmaxKernel<<>>(in_ptr,pre_amax_ptr,num_elems); + NVTE_CHECK_CUDA(cudaGetLastError()); + } + + if (want_identity || want_trans) { + dim3 grid = transform_grid(num_rows, row_length), block(kThreadsPerBlock); +#define LAUNCH_A(IDENT,TRANS,UA,UAT) \ + HadamardTransformKernel \ + <<>>(in_ptr,nullptr,nullptr, \ + random_sign_mask,random_sign_mask_t, \ + (uint64_t)num_rows,(uint64_t)row_length, \ + id_amax_ptr,tr_amax_ptr,false) + + if (want_identity && want_trans) + LAUNCH_A(true, true, true, true); + else if (want_identity) + LAUNCH_A(true, false, true, false); + else + LAUNCH_A(false,true, false, true); + + NVTE_CHECK_CUDA(cudaGetLastError()); +#undef LAUNCH_A + } +#else // __HIP_PLATFORM_AMD__ #if CUDA_VERSION >= 12080 // Check input tensor @@ -853,6 +1170,7 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif // CUDA_VERSION >= 12080 +#endif // __HIP_PLATFORM_AMD__ } } // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h index 6785040df..90a722e66 
100644 --- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -13,8 +13,6 @@ #ifndef TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ #define TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ -#ifndef __HIP_PLATFORM_AMD__ - #include "transformer_engine.h" #ifdef __cplusplus @@ -69,6 +67,4 @@ void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTE } // extern "C" #endif -#endif - #endif // TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 27f24a961..6c19bae13 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -293,7 +293,6 @@ class MXFP8Quantizer : public Quantizer { std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; -#ifndef USE_ROCM class NVFP4Quantizer : public Quantizer { public: // fp4 dtype @@ -347,7 +346,6 @@ class NVFP4Quantizer : public Quantizer { void quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, bool compute_amax); }; -#endif // #ifndef USE_ROCM std::unique_ptr convert_quantizer(py::handle quantizer); diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index d5fd4a4fe..b924e8a77 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -110,12 +110,8 @@ constexpr std::array custom_types_converters = { CreateQuantizer), std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers, NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer), -#ifdef USE_ROCM -}; -#else std::make_tuple(IsNVFP4Tensor, IsNVFP4Quantizers, NVTETensorFromNVFP4Tensor, CreateQuantizer)}; -#endif } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp 
b/transformer_engine/pytorch/csrc/quantizer.cpp index 50c6bc810..7ca0ab3bf 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1163,7 +1163,6 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s #endif } -#ifndef USE_ROCM NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); this->with_rht = quantizer.attr("with_rht").cast(); @@ -1516,6 +1515,10 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou bool eligible_for_rht_cast_fusion = input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; +#ifdef USE_ROCM + eligible_for_rht_cast_fusion = false; +#endif + // Compute amax. if (this->with_rht) { if (input.dtype() != DType::kBFloat16) { @@ -1663,6 +1666,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), quant_config, stream); }); } else { +#ifndef USE_ROCM // RHT cast fusion kernel. 
NVTE_CHECK(this->rht_matrix.defined() && this->rht_matrix.numel() > 0, "RHT matrix is not set"); @@ -1671,6 +1675,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou nvte_hadamard_transform_cast_fusion_columnwise( input.data(), out_transpose.data(), rht_matrix_nvte.data(), quant_config, stream); }); +#endif } } } else { @@ -1740,6 +1745,5 @@ std::vector NVFP4Quantizer::get_scale_shape(const std::vector& s } return scale_shape; } -#endif } // namespace transformer_engine::pytorch From bda7b13b9e1ed4af99457c83bfc13ac51cede85e Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 30 Mar 2026 12:12:10 -0500 Subject: [PATCH 2/4] test update --- .../nvfp4/test_nvfp4_rht_quantize_exact.py | 28 +++++++++++-------- .../hadamard_transform/hadamard_transform.cu | 2 ++ 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index b3b5ad8df..5fd55e059 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -249,22 +251,26 @@ def test_nvfp4_quantization_noncontiguous_inputs( def _ref_wht16_tiled(x: torch.Tensor, sign_mask: int) -> torch.Tensor: - """Pure-Python reference WHT: tiled 16-point butterfly, normalised by 0.25.""" - import numpy as np - x_np = x.float().cpu().numpy().copy() - rows, cols = x_np.shape - d = np.array([((-1) ** ((sign_mask >> i) & 1)) for i in range(16)], dtype=np.float32) + """Reference 16-point WHT tiled along last dim, normalised by 0.25.""" + x = x.float() + _rows, cols = x.shape + d = torch.tensor( + [((-1) ** ((sign_mask >> i) & 1)) for i in range(16)], + dtype=torch.float32, device=x.device, + ) + out = x.clone() for c in range(0, cols, 16): - tile = x_np[:, c:c+16] * d + tile = out[:, c:c+16] * d # apply sign h = 1 while h < 16: for i in range(0, 16, h * 2): - for j in range(i, i + h): - a, b = tile[:, j].copy(), tile[:, j + h].copy() - tile[:, j], tile[:, j + h] = a + b, a - b + a = tile[:, i:i+h].clone() + b = tile[:, i+h:i+2*h].clone() + tile[:, i:i+h] = a + b + tile[:, i+h:i+2*h] = a - b h *= 2 - x_np[:, c:c+16] = tile * 0.25 - return torch.from_numpy(x_np) + out[:, c:c+16] = tile * 0.25 + return out @pytest.mark.parametrize("rows,cols", [(64, 64), (128, 128)]) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index e2f449673..64aa243e1 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
From 63c7a48730f6ba44aae9c831ad5be4a1f8a82f11 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 30 Mar 2026 15:42:04 -0500 Subject: [PATCH 3/4] amax opt --- .../hadamard_transform/hadamard_transform.cu | 100 ++++++++++++++---- 1 file changed, 82 insertions(+), 18 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 64aa243e1..3c8a7e53d 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -732,6 +732,7 @@ __device__ __forceinline__ void wht16( auto sgn = [&](int k) -> float { return ((sign_mask >> (thread_in_group * kElemsPerThread + k)) & 1u) ? -1.f : 1.f; }; + if (apply_pre) { v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3); } @@ -755,7 +756,10 @@ __device__ __forceinline__ void wht16( v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); } v0*=kHadamardScale; v1*=kHadamardScale; v2*=kHadamardScale; v3*=kHadamardScale; - if (!apply_pre) { v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3); } + + if (!apply_pre) { + v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3); + } } // Grid: blockIdx.x = col tile [0, row_length/16) @@ -786,8 +790,12 @@ void HadamardTransformKernel( const bool apply_pre = !inverse_hadamard; const bool in_bounds = (global_row < num_rows) && (col_base + kElemsPerThread - 1 < row_length); - // Smem for transposed path: 64×(16+1) BF16; +1 avoids LDS bank conflict. + // Smem for transposed path: 64*(16+1) BF16; +1 avoids LDS bank conflict. 
__shared__ __hip_bfloat16 smem[kRowsPerBlock][kHadamardDim + 1]; + + __shared__ float block_amax[kWarpsPerBlock]; + __shared__ float block_amax_t[kWarpsPerBlock]; + float v0=0.f, v1=0.f, v2=0.f, v3=0.f; if (in_bounds) { unpack_bf16x4(*reinterpret_cast( @@ -797,24 +805,30 @@ void HadamardTransformKernel( // Identity path: WHT along row dimension if constexpr (kComputeIdentity || kUpdateAmax) { float r0=v0, r1=v1, r2=v2, r3=v3; + float lam = 0.f; if (global_row < num_rows) { wht16(r0, r1, r2, r3, thread_in_grp, random_sign_mask, apply_pre); if constexpr (kUpdateAmax) { - float lam = fmaxf(fmaxf(fabsf(r0),fabsf(r1)),fmaxf(fabsf(r2),fabsf(r3))); - for (int off=kWarpSize/2; off>=1; off>>=1) lam=fmaxf(lam,__shfl_xor(lam,off)); - if (lane_id == 0) atomicMaxFloat(amax, lam); + lam = fmaxf(fmaxf(fabsf(r0),fabsf(r1)),fmaxf(fabsf(r2),fabsf(r3))); + for (int off=kWarpSize/2; off>=1; off>>=1) + lam=fmaxf(lam,__shfl_xor(lam,off)); } if constexpr (kComputeIdentity) if (output && in_bounds) *reinterpret_cast(&output[global_row*row_length+col_base]) = pack_bf16x4(r0,r1,r2,r3); } + if constexpr (kUpdateAmax) { + if (lane_id == 0) + block_amax[warp_id] = lam; + } } // Transposed path: WHT along column dimension via smem transpose if constexpr (kComputeTransposed || kUpdateAmaxT) { const int local_row = warp_id * kRowsPerWarp + row_in_warp; const int col_offset = thread_in_grp * kElemsPerThread; + float lam = 0.f; smem[local_row][col_offset+0] = to_bf16(global_row < num_rows ? v0 : 0.f); smem[local_row][col_offset+1] = to_bf16(global_row < num_rows ? v1 : 0.f); smem[local_row][col_offset+2] = to_bf16(global_row < num_rows ? 
v2 : 0.f); @@ -831,9 +845,10 @@ void HadamardTransformKernel( wht16(c0, c1, c2, c3, thread_in_grp, random_sign_mask_t, apply_pre); if constexpr (kUpdateAmaxT) { - float lam = fmaxf(fmaxf(fabsf(c0),fabsf(c1)),fmaxf(fabsf(c2),fabsf(c3))); - for (int off=kWarpSize/2; off>=1; off>>=1) lam=fmaxf(lam,__shfl_xor(lam,off)); - if (lane_id == 0) atomicMaxFloat(amax_t, lam); + lam = fmaxf(fmaxf(fabsf(c0),fabsf(c1)),fmaxf(fabsf(c2),fabsf(c3))); + + for (int off=kWarpSize/2; off>=1; off>>=1) + lam=fmaxf(lam,__shfl_xor(lam,off)); } if constexpr (kComputeTransposed) { @@ -847,18 +862,67 @@ void HadamardTransformKernel( pack_bf16x4(c0,c1,c2,c3); } } + + if constexpr (kUpdateAmaxT) { + if (lane_id == 0) + block_amax_t[warp_id] = lam; + } + } + + if constexpr (kUpdateAmax || kUpdateAmaxT) { + __syncthreads(); + + if (warp_id == 0) { + if constexpr (kUpdateAmax) { + float block_lam = (lane_id < kWarpsPerBlock) ? block_amax[lane_id] : 0.f; + + for (int off=kWarpSize/2; off>=1; off>>=1) + block_lam = fmaxf(block_lam, __shfl_xor(block_lam, off)); + + if (lane_id == 0) + atomicMaxFloat(amax, block_lam); + } + + if constexpr (kUpdateAmaxT) { + float block_lam_t = (lane_id < kWarpsPerBlock) ? block_amax_t[lane_id] : 0.f; + + for (int off=kWarpSize/2; off>=1; off>>=1) + block_lam_t = fmaxf(block_lam_t, __shfl_xor(block_lam_t, off)); + + if (lane_id == 0) + atomicMaxFloat(amax_t, block_lam_t); + } + } } } // Pre-RHT amax: max|input| before any transform. 
__global__ void PreRhtAmaxKernel(const __hip_bfloat16* __restrict__ input, float* __restrict__ amax_out, uint64_t num_elems) { + __shared__ float block_amax[kWarpsPerBlock]; float lam = 0.f; for (uint64_t i = (uint64_t)blockIdx.x*blockDim.x+threadIdx.x; i < num_elems; i += (uint64_t)gridDim.x*blockDim.x) - lam = fmaxf(lam, fabsf(to_f32(input[i]))); - for (int off=kWarpSize/2; off>=1; off>>=1) lam=fmaxf(lam,__shfl_xor(lam,off)); - if (threadIdx.x % kWarpSize == 0) atomicMaxFloat(amax_out, lam); + lam = fmaxf(lam, fabsf(to_f32(input[i]))); + + for (int off=kWarpSize/2; off>=1; off>>=1) + lam=fmaxf(lam,__shfl_xor(lam,off)); + + const int warp_id = threadIdx.x / kWarpSize; + const int lane_id = threadIdx.x % kWarpSize; + if (lane_id == 0) + block_amax[warp_id] = lam; + + __syncthreads(); + + if (warp_id == 0) { + float block_lam = (lane_id < kWarpsPerBlock) ? block_amax[lane_id] : 0.f; + for (int off=kWarpSize/2; off>=1; off>>=1) + block_lam=fmaxf(block_lam,__shfl_xor(block_lam,off)); + + if (lane_id == 0) + atomicMaxFloat(amax_out, block_lam); + } } static inline dim3 transform_grid(uint64_t num_rows, uint64_t row_length) { @@ -878,7 +942,7 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s NVTE_CHECK(input_.dim() >= 2, "Input must be >=2D."); const SimpleTensor& input = input_.data; - SimpleTensor identity_out; + SimpleTensor identity_out; // Unused SimpleTensor& transposed_out = output_.data; const bool want_identity = (identity_out.dptr != nullptr); @@ -904,9 +968,9 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s #define LAUNCH_T(IDENT, TRANS) \ HadamardTransformKernel \ - <<>>(in_ptr,id_ptr,tr_ptr, \ - random_sign_mask,random_sign_mask_t, \ - (uint64_t)num_rows,(uint64_t)row_length,nullptr,nullptr,false) + <<>>(in_ptr, id_ptr, tr_ptr, \ + random_sign_mask, random_sign_mask_t, \ + (uint64_t)num_rows, (uint64_t)row_length, nullptr, nullptr, false) if (want_identity && want_transposed) 
LAUNCH_T(true, true); @@ -992,8 +1056,8 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s #endif // __HIP_PLATFORM_AMD__ } -// Kernel that will apply the 16x16 hadamard transform the input and input.T, and then -// get the absolute max value of the result. +// Kernel that applies the 16x16 hadamard transform the input and input.T, and then +// gets the absolute max value of the result. void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask, uint16_t random_sign_mask_t, cudaStream_t stream) { NVTE_API_CALL(hadamard_transform_amax); @@ -1003,7 +1067,7 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran const SimpleTensor& input = input_.data; SimpleTensor& pre_rht_tensor = output_.amax; - SimpleTensor identity_tensor; + SimpleTensor identity_tensor; // Unused SimpleTensor& transpose_tensor = output_.columnwise_amax; const bool want_pre_rht = (pre_rht_tensor.dptr != nullptr); From a2604591481f0417a14bb42ae7d1ee867a6fd076 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 30 Mar 2026 17:07:54 -0500 Subject: [PATCH 4/4] simplify --- .../hadamard_transform/hadamard_transform.cu | 133 +++++++----------- 1 file changed, 53 insertions(+), 80 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 3c8a7e53d..d574d9ea2 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -1056,77 +1056,12 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s #endif // __HIP_PLATFORM_AMD__ } -// Kernel that applies the 16x16 hadamard transform the input and input.T, and then -// gets the absolute max value of the result. 
+// Kernel that applies the 16x16 hadamard transform to the input and input.T, and then +// gets the absolute max value of the result. void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask, uint16_t random_sign_mask_t, cudaStream_t stream) { NVTE_API_CALL(hadamard_transform_amax); -#ifdef __HIP_PLATFORM_AMD__ - NVTE_CHECK(input_.dtype() == DType::kBFloat16, "Input must be BF16."); - NVTE_CHECK(input_.dim() >= 2, "Input must be >=2D."); - - const SimpleTensor& input = input_.data; - SimpleTensor& pre_rht_tensor = output_.amax; - SimpleTensor identity_tensor; // Unused - SimpleTensor& transpose_tensor = output_.columnwise_amax; - - const bool want_pre_rht = (pre_rht_tensor.dptr != nullptr); - const bool want_identity = (identity_tensor.dptr != nullptr); - const bool want_trans = (transpose_tensor.dptr != nullptr); - - if (!want_pre_rht && !want_identity && !want_trans) - return; - - const size_t ndim = input.shape.size(); - const size_t row_length = input.shape[ndim - 1]; - size_t num_rows = 1; - - for (size_t i = 0; i < ndim - 1; ++i) - num_rows *= input.shape[i]; - - NVTE_CHECK(row_length % kHadamardDim == 0, "row_length must be divisible by 16."); - NVTE_CHECK(num_rows % kHadamardDim == 0, "num_rows must be divisible by 16."); - - auto* in_ptr = reinterpret_cast(input.dptr); - auto* pre_amax_ptr = reinterpret_cast(pre_rht_tensor.dptr); - auto* id_amax_ptr = reinterpret_cast(identity_tensor.dptr); - auto* tr_amax_ptr = reinterpret_cast(transpose_tensor.dptr); - - if (pre_amax_ptr) - NVTE_CHECK_CUDA(cudaMemsetAsync(pre_amax_ptr, 0, sizeof(float), stream)); - if (id_amax_ptr) - NVTE_CHECK_CUDA(cudaMemsetAsync(id_amax_ptr, 0, sizeof(float), stream)); - if (tr_amax_ptr) - NVTE_CHECK_CUDA(cudaMemsetAsync(tr_amax_ptr, 0, sizeof(float), stream)); - - if (want_pre_rht) { - const uint64_t num_elems = (uint64_t)num_rows * row_length; - dim3 g(DIVUP(num_elems, (uint64_t)kThreadsPerBlock)); - 
PreRhtAmaxKernel<<>>(in_ptr,pre_amax_ptr,num_elems); - NVTE_CHECK_CUDA(cudaGetLastError()); - } - - if (want_identity || want_trans) { - dim3 grid = transform_grid(num_rows, row_length), block(kThreadsPerBlock); -#define LAUNCH_A(IDENT,TRANS,UA,UAT) \ - HadamardTransformKernel \ - <<>>(in_ptr,nullptr,nullptr, \ - random_sign_mask,random_sign_mask_t, \ - (uint64_t)num_rows,(uint64_t)row_length, \ - id_amax_ptr,tr_amax_ptr,false) - - if (want_identity && want_trans) - LAUNCH_A(true, true, true, true); - else if (want_identity) - LAUNCH_A(true, false, true, false); - else - LAUNCH_A(false,true, false, true); - - NVTE_CHECK_CUDA(cudaGetLastError()); -#undef LAUNCH_A - } -#else // __HIP_PLATFORM_AMD__ -#if CUDA_VERSION >= 12080 +#if CUDA_VERSION >= 12080 || defined(__HIP_PLATFORM_AMD__) // Check input tensor NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, @@ -1151,16 +1086,6 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran return; } - // Zero out amaxes if needed - ZeroAmaxKernel<<<1, 1, 0, stream>>>(reinterpret_cast(output_pre_rht_amax.dptr), - reinterpret_cast(output_identity_amax.dptr), - reinterpret_cast(output_transpose_amax.dptr)); - NVTE_CHECK_CUDA(cudaGetLastError()); - - checkCuDriverContext(stream); - - using IType = bf16; - const size_t ndim = input.shape.size(); const size_t row_length = input.shape[ndim - 1]; size_t num_rows = 1; @@ -1168,6 +1093,41 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran num_rows *= input.shape[i]; } +#ifdef __HIP_PLATFORM_AMD__ + auto* pre_amax_ptr = reinterpret_cast(output_pre_rht_amax.dptr); + auto* id_amax_ptr = reinterpret_cast(output_identity_amax.dptr); + auto* tr_amax_ptr = reinterpret_cast(output_transpose_amax.dptr); + + NVTE_CHECK(row_length % kHadamardDim == 0, "row_length must be divisible by 16."); + NVTE_CHECK(num_rows % kHadamardDim == 0, "num_rows must be divisible by 16."); + + auto* in_ptr = reinterpret_cast(input.dptr); + + if 
(pre_amax_ptr) { + NVTE_CHECK_CUDA(cudaMemsetAsync(pre_amax_ptr, 0, sizeof(float), stream)); + } + if (id_amax_ptr) { + NVTE_CHECK_CUDA(cudaMemsetAsync(id_amax_ptr, 0, sizeof(float), stream)); + } + if (tr_amax_ptr) { + NVTE_CHECK_CUDA(cudaMemsetAsync(tr_amax_ptr, 0, sizeof(float), stream)); + } + + if (return_pre_rht_amax) { + const uint64_t num_elems = static_cast(num_rows) * row_length; + dim3 grid(DIVUP(num_elems, static_cast(kThreadsPerBlock))); + PreRhtAmaxKernel<<>>(in_ptr, pre_amax_ptr, num_elems); + NVTE_CHECK_CUDA(cudaGetLastError()); + } +#else + // Zero out amaxes if needed + ZeroAmaxKernel<<<1, 1, 0, stream>>>(pre_amax_ptr, id_amax_ptr, tr_amax_ptr); + NVTE_CHECK_CUDA(cudaGetLastError()); + + checkCuDriverContext(stream); + + using IType = bf16; + constexpr int kHadamardDimension = 16; NVTE_CHECK(row_length % kHadamardDimension == 0, "row_length must be divisible by hadamard_dimension."); @@ -1200,6 +1160,7 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran dim3 block(kThreadBlockX * kThreadsPerWarp, kThreadBlockY); dim3 grid(DIVUP(row_length, kChunkBlockXSmall), DIVUP(num_rows, kChunkBlockYSmall)); +#endif TRANSFORMER_ENGINE_SWITCH_CONDITION( return_transposed_amax, kReturnTransposedAmax, @@ -1207,6 +1168,17 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran TRANSFORMER_ENGINE_SWITCH_CONDITION( return_identity_amax, kReturnIdentityAmax, +#ifdef __HIP_PLATFORM_AMD__ + if (kReturnIdentityAmax || kReturnTransposedAmax) { + dim3 grid = transform_grid(num_rows, row_length), block(kThreadsPerBlock); + HadamardTransformKernel + <<>>(in_ptr, nullptr, nullptr, random_sign_mask, + random_sign_mask_t, static_cast(num_rows), + static_cast(row_length), id_amax_ptr, + tr_amax_ptr, false); + } +#else TRANSFORMER_ENGINE_SWITCH_CONDITION( return_pre_rht_amax, kReturnPreRhtAmax, @@ -1230,13 +1202,14 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran 
reinterpret_cast(output_identity_amax.dptr), reinterpret_cast(output_transpose_amax.dptr), random_sign_mask, random_sign_mask_t, num_rows, row_length);))); +#endif + )); NVTE_CHECK_CUDA(cudaGetLastError()); #else NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); -#endif // CUDA_VERSION >= 12080 -#endif // __HIP_PLATFORM_AMD__ +#endif // CUDA_VERSION >= 12080 || __HIP_PLATFORM_AMD__ } } // namespace transformer_engine