diff --git a/build_tools/hipify/custom_map.json b/build_tools/hipify/custom_map.json index 872d38efa..92b3f0a44 100644 --- a/build_tools/hipify/custom_map.json +++ b/build_tools/hipify/custom_map.json @@ -15,7 +15,8 @@ "__nv_fp4x4_e2m1" : "__hip_fp4x4_e2m1", "__nv_fp4x2_storage_t" : "__hip_fp4x2_storage_t", "#include " : "", - "#include " : "" + "#include " : "", + "#include " : "" } } diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index 98be9a4f5..5fd55e059 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -246,3 +248,73 @@ def test_nvfp4_quantization_noncontiguous_inputs( use_cpp_allocator=use_cpp_allocator, with_random_sign_mask=with_random_sign_mask, ) + + +def _ref_wht16_tiled(x: torch.Tensor, sign_mask: int) -> torch.Tensor: + """Reference 16-point WHT tiled along last dim, normalised by 0.25.""" + x = x.float() + _rows, cols = x.shape + d = torch.tensor( + [((-1) ** ((sign_mask >> i) & 1)) for i in range(16)], + dtype=torch.float32, device=x.device, + ) + out = x.clone() + for c in range(0, cols, 16): + tile = out[:, c:c+16] * d # apply sign + h = 1 + while h < 16: + for i in range(0, 16, h * 2): + a = tile[:, i:i+h].clone() + b = tile[:, i+h:i+2*h].clone() + tile[:, i:i+h] = a + b + tile[:, i+h:i+2*h] = a - b + h *= 2 + out[:, c:c+16] = tile * 0.25 + return out + + +@pytest.mark.parametrize("rows,cols", [(64, 64), (128, 128)]) +def test_hadamard_transform_amax(rows, cols): + """ + Tests nvte_hadamard_transform_amax via NVFP4Quantizer (with_rht=True). + Exercises the WHT kernel without requiring a full NVFP4 recipe. 
+ Checks: + - amax_rowwise == max|x| (pre-RHT amax of raw input) + - amax_colwise == max|WHT(x.T)| (post-RHT amax of transposed input) + """ + torch.manual_seed(42) + x = torch.randn((rows, cols), dtype=torch.bfloat16, device="cuda").contiguous() + + quantizer = NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=True, + with_post_rht_amax=True, + with_random_sign_mask=True, + ) + out = quantizer(x) + + # amax_rowwise: pre-RHT, should equal max|x| + expected_rowwise_amax = x.float().abs().max() + torch.testing.assert_close( + out._amax_rowwise.float().squeeze(), + expected_rowwise_amax, + rtol=1e-3, atol=1e-3, + msg=f"pre-RHT amax mismatch rows={rows} cols={cols}", + ) + + # amax_colwise: post-RHT of x.T, should equal max|WHT(x.T)| + sign_mask_t = quantizer.rht_matrix_random_sign_mask_t + x_t = x.t().contiguous() # (cols, rows) + wht_x_t = _ref_wht16_tiled(x_t, sign_mask=sign_mask_t) + expected_colwise_amax = wht_x_t.float().abs().max() + + torch.testing.assert_close( + out._amax_columnwise.float().squeeze().item(), + float(expected_colwise_amax), + rtol=2e-2, atol=2e-2, + msg=f"post-RHT amax mismatch rows={rows} cols={cols}", + ) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 8d5537368..847fbcb8e 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -223,6 +223,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources activation/gelu.cu activation/relu.cu activation/swiglu.cu + hadamard_transform/hadamard_transform.cu transpose/quantize_transpose_vector_blockwise_fp4.cu) if(USE_CUDA) @@ -247,7 +248,6 @@ if(USE_CUDA) list(APPEND transformer_engine_cuda_arch_specific_sources gemm/cutlass_grouped_gemm.cu transpose/quantize_transpose_square_blockwise.cu - hadamard_transform/hadamard_transform.cu hadamard_transform/hadamard_transform_cast_fusion.cu) else() #ROCm specific source codes diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index b901f9023..d574d9ea2 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -23,6 +25,8 @@ namespace {
 constexpr int kThreadsPerWarp = 32;
 constexpr float k16x16HadamardScale = 0.25f;
 
+#ifndef __HIP_PLATFORM_AMD__
+
 template
 __device__ __forceinline__ void ldmatrix_x4_m8n8_shared_b16(uint32_t& a0, uint32_t& a1,
                                                             uint32_t& a2, uint32_t& a3,
@@ -658,12 +662,325 @@ __global__ void HadamardTransformKernel(const T* __restrict__ input, T* __restri
 #endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
 }
 
+#endif  // __HIP_PLATFORM_AMD__
+}  // namespace
+
+#ifdef __HIP_PLATFORM_AMD__
+
+namespace {
+
+static constexpr int kHadamardDim = 16;
+static constexpr int kWarpSize = 64;
+static constexpr int kThreadsPerWHT = 4;
+static constexpr int kElemsPerThread = 4;
+static constexpr int kRowsPerWarp = kWarpSize / kThreadsPerWHT;      // 16
+static constexpr int kWarpsPerBlock = 4;
+static constexpr int kRowsPerBlock = kRowsPerWarp * kWarpsPerBlock;  // 64
+static constexpr int kThreadsPerBlock = kWarpSize * kWarpsPerBlock;  // 256
+static constexpr float kHadamardScale = 0.25f;
+
+// ds_swizzle: sub-wavefront exchange without LDS.
+// Same instructions as cast_transpose_mxfp4_kernel_shuffled.cu.
+__device__ __forceinline__ float ds_swizzle_xor1(float v) {
+  float r;
+  asm volatile("ds_swizzle_b32 %0, %1 offset:0x041F\n\t"
+               "s_waitcnt lgkmcnt(0)"
+               : "=v"(r) : "v"(v));
+  return r;
+}
+
+__device__ __forceinline__ float ds_swizzle_xor2(float v) {
+  float r;
+  asm volatile("ds_swizzle_b32 %0, %1 offset:0x081F\n\t"
+               "s_waitcnt lgkmcnt(0)"
+               : "=v"(r) : "v"(v));
+  return r;
+}
+
+// BF16 helpers
+__device__ __forceinline__ float to_f32(__hip_bfloat16 v) { return static_cast<float>(v); }
+__device__ __forceinline__ __hip_bfloat16 to_bf16(float v) { return static_cast<__hip_bfloat16>(v); }
+
+// Bit-cast __hip_bfloat16 -> uint16_t without address-of-temporary.
+__device__ __forceinline__ uint16_t bf16_to_bits(__hip_bfloat16 v) {
+  uint16_t bits;
+  __builtin_memcpy(&bits, &v, sizeof(uint16_t));
+  return bits;
+}
+
+// Unpack/pack 4 BF16 values as uint64_t (vectorised global load/store).
+// Same trick as cast_transpose_mxfp4_kernel_shuffled.cu::bf16x4_to_float4.
+__device__ __forceinline__ void unpack_bf16x4(uint64_t p,
+                                              float& v0, float& v1, float& v2, float& v3) {
+  v0 = __uint_as_float(((uint32_t)( p        & 0xFFFF)) << 16);
+  v1 = __uint_as_float(((uint32_t)((p >> 16) & 0xFFFF)) << 16);
+  v2 = __uint_as_float(((uint32_t)((p >> 32) & 0xFFFF)) << 16);
+  v3 = __uint_as_float(((uint32_t)((p >> 48) & 0xFFFF)) << 16);
+}
+
+__device__ __forceinline__ uint64_t pack_bf16x4(float v0, float v1, float v2, float v3) {
+  return  (uint64_t)bf16_to_bits(to_bf16(v0))
+       | ((uint64_t)bf16_to_bits(to_bf16(v1)) << 16)
+       | ((uint64_t)bf16_to_bits(to_bf16(v2)) << 32)
+       | ((uint64_t)bf16_to_bits(to_bf16(v3)) << 48);
+}
+
+// 16-point WHT: in-register, no shared memory.
+// Adapted from cast_transpose_mxfp4_kernel_shuffled.cu::hadamard16_inplace,
+// extended with NV random_sign_mask (uint16_t bitmask).
+// thread_in_group [0,3]: drives ds_swizzle polarity (identical to MLPerf tid & 3).
+// apply_pre=true -> D before WHT (forward); false -> D after WHT (inverse).
+__device__ __forceinline__ void wht16(
+    float& v0, float& v1, float& v2, float& v3,
+    int thread_in_group, uint16_t sign_mask, bool apply_pre) {
+  auto sgn = [&](int k) -> float {
+    return ((sign_mask >> (thread_in_group * kElemsPerThread + k)) & 1u) ?
-1.f : 1.f;
+  };
+
+  if (apply_pre) {
+    v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3);
+  }
+
+  // Stage 1: local H4
+  float a0=v0+v1, a1=v0-v1, a2=v2+v3, a3=v2-v3;
+  v0=a0+a2; v2=a0-a2; v1=a1+a3; v3=a1-a3;
+
+  // Stage 2: cross-thread XOR-1
+  { float p0=ds_swizzle_xor1(v0), p1=ds_swizzle_xor1(v1),
+          p2=ds_swizzle_xor1(v2), p3=ds_swizzle_xor1(v3);
+    bool up=(thread_in_group&1);
+    v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1);
+    v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); }
+
+  // Stage 3: cross-thread XOR-2
+  { float p0=ds_swizzle_xor2(v0), p1=ds_swizzle_xor2(v1),
+          p2=ds_swizzle_xor2(v2), p3=ds_swizzle_xor2(v3);
+    bool up=(thread_in_group>>1)&1;
+    v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1);
+    v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); }
+
+  v0*=kHadamardScale; v1*=kHadamardScale; v2*=kHadamardScale; v3*=kHadamardScale;
+
+  if (!apply_pre) {
+    v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3);
+  }
+}
+
+// Grid:  blockIdx.x = col tile  [0, row_length/16)
+//        blockIdx.y = row batch [0, ceil(num_rows/64))
+// Block: 256 threads = 4 wavefronts of 64 lanes.
+//        lane/4 = row_in_warp (0..15), lane%4 = thread_in_grp (0..3)
+template <bool kComputeIdentity, bool kComputeTransposed, bool kUpdateAmax, bool kUpdateAmaxT>
+__global__ __launch_bounds__(kThreadsPerBlock, 4)
+void HadamardTransformKernel(
+    const __hip_bfloat16* __restrict__ input,
+    __hip_bfloat16* __restrict__ output,
+    __hip_bfloat16* __restrict__ output_t,
+    uint16_t random_sign_mask, uint16_t random_sign_mask_t,
+    uint64_t num_rows, uint64_t row_length,
+    float* __restrict__ amax, float* __restrict__ amax_t,
+    bool inverse_hadamard) {
+  const int tid = threadIdx.x;
+  const int warp_id = tid / kWarpSize;
+  const int lane_id = tid % kWarpSize;
+  const int row_in_warp = lane_id / kThreadsPerWHT;
+  const int thread_in_grp = lane_id % kThreadsPerWHT;
+  const uint64_t col_tile_base = (uint64_t)blockIdx.x * kHadamardDim;
+  const uint64_t row_batch = (uint64_t)blockIdx.y * kRowsPerBlock;
+  const uint64_t global_row = row_batch + (uint64_t)warp_id*kRowsPerWarp + row_in_warp;
+  const uint64_t col_base = col_tile_base + (uint64_t)thread_in_grp * kElemsPerThread;
+
+  const bool apply_pre = !inverse_hadamard;
+  const bool in_bounds = (global_row < num_rows) && (col_base + kElemsPerThread - 1 < row_length);
+
+  // Smem for transposed path: 64*(16+1) BF16; +1 avoids LDS bank conflict.
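+  // Tile ownership: each wavefront handles a 16x16 BF16 tile; row_in_warp (lane/4)
+  // selects the row and thread_in_grp (lane%4) selects 4 contiguous columns, so the
+  // 4 lanes of one ds_swizzle group hold a full 16-element row in registers.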
+  __shared__ __hip_bfloat16 smem[kRowsPerBlock][kHadamardDim + 1];
+
+  __shared__ float block_amax[kWarpsPerBlock];
+  __shared__ float block_amax_t[kWarpsPerBlock];
+
+  float v0=0.f, v1=0.f, v2=0.f, v3=0.f;
+  if (in_bounds) {
+    unpack_bf16x4(*reinterpret_cast<const uint64_t*>(
+                      &input[global_row * row_length + col_base]), v0, v1, v2, v3);
+  }
+
+  // Identity path: WHT along row dimension
+  if constexpr (kComputeIdentity || kUpdateAmax) {
+    float r0=v0, r1=v1, r2=v2, r3=v3;
+    float lam = 0.f;
+    if (global_row < num_rows) {
+      wht16(r0, r1, r2, r3, thread_in_grp, random_sign_mask, apply_pre);
+      if constexpr (kUpdateAmax) {
+        lam = fmaxf(fmaxf(fabsf(r0),fabsf(r1)),fmaxf(fabsf(r2),fabsf(r3)));
+        for (int off=kWarpSize/2; off>=1; off>>=1)
+          lam=fmaxf(lam,__shfl_xor(lam,off));
+      }
+      if constexpr (kComputeIdentity)
+        if (output && in_bounds)
+          *reinterpret_cast<uint64_t*>(&output[global_row*row_length+col_base]) =
+              pack_bf16x4(r0,r1,r2,r3);
+    }
+    if constexpr (kUpdateAmax) {
+      if (lane_id == 0)
+        block_amax[warp_id] = lam;
+    }
+  }
+
+  // Transposed path: WHT along column dimension via smem transpose
+  if constexpr (kComputeTransposed || kUpdateAmaxT) {
+    const int local_row = warp_id * kRowsPerWarp + row_in_warp;
+    const int col_offset = thread_in_grp * kElemsPerThread;
+    float lam = 0.f;
+    smem[local_row][col_offset+0] = to_bf16(global_row < num_rows ? v0 : 0.f);
+    smem[local_row][col_offset+1] = to_bf16(global_row < num_rows ? v1 : 0.f);
+    smem[local_row][col_offset+2] = to_bf16(global_row < num_rows ? v2 : 0.f);
+    smem[local_row][col_offset+3] = to_bf16(global_row < num_rows ? v3 : 0.f);
+    __syncthreads();
+
+    // Re-read: row_in_warp -> column index, thread_in_grp -> 4 rows
+    const int t_col = row_in_warp;
+    const int smem_rbase = warp_id*kRowsPerWarp + thread_in_grp*kElemsPerThread;
+
+    float c0=to_f32(smem[smem_rbase+0][t_col]), c1=to_f32(smem[smem_rbase+1][t_col]);
+    float c2=to_f32(smem[smem_rbase+2][t_col]), c3=to_f32(smem[smem_rbase+3][t_col]);
+
+    wht16(c0, c1, c2, c3, thread_in_grp, random_sign_mask_t, apply_pre);
+
+    if constexpr (kUpdateAmaxT) {
+      lam = fmaxf(fmaxf(fabsf(c0),fabsf(c1)),fmaxf(fabsf(c2),fabsf(c3)));
+      for (int off=kWarpSize/2; off>=1; off>>=1)
+        lam=fmaxf(lam,__shfl_xor(lam,off));
+    }
+
+    if constexpr (kComputeTransposed) {
+      if (output_t) {
+        const uint64_t global_col = col_tile_base + t_col;
+        const uint64_t out_row_base = row_batch + (uint64_t)warp_id*kRowsPerWarp +
+                                      (uint64_t)thread_in_grp*kElemsPerThread;
+        if (global_col < row_length && out_row_base+kElemsPerThread-1 < num_rows)
+          *reinterpret_cast<uint64_t*>(&output_t[global_col*num_rows+out_row_base]) =
+              pack_bf16x4(c0,c1,c2,c3);
+      }
+    }
+
+    if constexpr (kUpdateAmaxT) {
+      if (lane_id == 0)
+        block_amax_t[warp_id] = lam;
+    }
+  }
+
+  if constexpr (kUpdateAmax || kUpdateAmaxT) {
+    __syncthreads();
+
+    if (warp_id == 0) {
+      if constexpr (kUpdateAmax) {
+        float block_lam = (lane_id < kWarpsPerBlock) ? block_amax[lane_id] : 0.f;
+        for (int off=kWarpSize/2; off>=1; off>>=1)
+          block_lam = fmaxf(block_lam, __shfl_xor(block_lam, off));
+        if (lane_id == 0)
+          atomicMaxFloat(amax, block_lam);
+      }
+
+      if constexpr (kUpdateAmaxT) {
+        float block_lam_t = (lane_id < kWarpsPerBlock) ? block_amax_t[lane_id] : 0.f;
+        for (int off=kWarpSize/2; off>=1; off>>=1)
+          block_lam_t = fmaxf(block_lam_t, __shfl_xor(block_lam_t, off));
+        if (lane_id == 0)
+          atomicMaxFloat(amax_t, block_lam_t);
+      }
+    }
+  }
+}
+
+// Pre-RHT amax: max|input| before any transform.
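+// Grid-stride loop over all elements; wavefront-level max via __shfl_xor butterfly,
+// one partial per wavefront staged in LDS, then a single atomicMaxFloat per block.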
+__global__ void PreRhtAmaxKernel(const __hip_bfloat16* __restrict__ input,
+                                 float* __restrict__ amax_out, uint64_t num_elems) {
+  __shared__ float block_amax[kWarpsPerBlock];
+  float lam = 0.f;
+  for (uint64_t i = (uint64_t)blockIdx.x*blockDim.x+threadIdx.x;
+       i < num_elems; i += (uint64_t)gridDim.x*blockDim.x)
+    lam = fmaxf(lam, fabsf(to_f32(input[i])));
+
+  for (int off=kWarpSize/2; off>=1; off>>=1)
+    lam=fmaxf(lam,__shfl_xor(lam,off));
+
+  const int warp_id = threadIdx.x / kWarpSize;
+  const int lane_id = threadIdx.x % kWarpSize;
+  if (lane_id == 0)
+    block_amax[warp_id] = lam;
+
+  __syncthreads();
+
+  if (warp_id == 0) {
+    float block_lam = (lane_id < kWarpsPerBlock) ? block_amax[lane_id] : 0.f;
+    for (int off=kWarpSize/2; off>=1; off>>=1)
+      block_lam=fmaxf(block_lam,__shfl_xor(block_lam,off));
+
+    if (lane_id == 0)
+      atomicMaxFloat(amax_out, block_lam);
+  }
+}
+
+static inline dim3 transform_grid(uint64_t num_rows, uint64_t row_length) {
+  return dim3((uint32_t)(row_length / kHadamardDim),
+              (uint32_t)DIVUP(num_rows, (uint64_t)kRowsPerBlock));
+}
+
 }  // namespace
 
+#endif  // __HIP_PLATFORM_AMD__
+
 void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask,
                         uint16_t random_sign_mask_t, cudaStream_t stream) {
   NVTE_API_CALL(hadamard_transform);
+#ifdef __HIP_PLATFORM_AMD__
+  NVTE_CHECK(input_.dtype() == DType::kBFloat16, "Input must be BF16.");
+  NVTE_CHECK(input_.dim() >= 2, "Input must be >=2D.");
+
+  const SimpleTensor& input = input_.data;
+  SimpleTensor identity_out;  // Unused
+  SimpleTensor& transposed_out = output_.data;
+
+  const bool want_identity = (identity_out.dptr != nullptr);
+  const bool want_transposed = (transposed_out.dptr != nullptr);
+
+  if (!want_identity && !want_transposed)
+    return;
+
+  const size_t ndim = input.shape.size();
+  const size_t row_length = input.shape[ndim - 1];
+  size_t num_rows = 1;
+
+  for (size_t i = 0; i < ndim - 1; ++i)
+    num_rows *= input.shape[i];
+  NVTE_CHECK(row_length % kHadamardDim == 0, "row_length must be divisible by 16.");
+  NVTE_CHECK(num_rows % kHadamardDim == 0, "num_rows must be divisible by 16.");
+
+  auto* in_ptr = reinterpret_cast<const __hip_bfloat16*>(input.dptr);
+  auto* id_ptr = reinterpret_cast<__hip_bfloat16*>(identity_out.dptr);
+  auto* tr_ptr = reinterpret_cast<__hip_bfloat16*>(transposed_out.dptr);
+  dim3 grid = transform_grid(num_rows, row_length), block(kThreadsPerBlock);
+
+#define LAUNCH_T(IDENT, TRANS)                                                  \
+  HadamardTransformKernel<IDENT, TRANS, false, false>                           \
+      <<<grid, block, 0, stream>>>(in_ptr, id_ptr, tr_ptr,                      \
+                                   random_sign_mask, random_sign_mask_t,        \
+                                   (uint64_t)num_rows, (uint64_t)row_length,    \
+                                   nullptr, nullptr, false)
+
+  if (want_identity && want_transposed)
+    LAUNCH_T(true, true);
+  else if (want_identity)
+    LAUNCH_T(true, false);
+  else
+    LAUNCH_T(false, true);
+  NVTE_CHECK_CUDA(cudaGetLastError());
+#undef LAUNCH_T
+#else  // CUDA
   // Check tensors
   // NOTE (frsun): This is non-intuitive, we are writing the result of
   // transposed RHT to the output of rowwise.
@@ -736,6 +1053,7 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s
                               num_rows, row_length, nullptr, nullptr, false);`);`);`
   NVTE_CHECK_CUDA(cudaGetLastError());
+#endif  // __HIP_PLATFORM_AMD__
 }
 
 // Kernel that will apply the 16x16 hadamard transform the input and input.T, and then
@@ -743,7 +1061,7 @@ void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_s
 void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask,
                              uint16_t random_sign_mask_t, cudaStream_t stream) {
   NVTE_API_CALL(hadamard_transform_amax);
-#if CUDA_VERSION >= 12080
+#if CUDA_VERSION >= 12080 || defined(__HIP_PLATFORM_AMD__)
 
   // Check input tensor
   NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
@@ -768,16 +1086,6 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran
     return;
   }
 
-  // Zero out amaxes if needed
-  ZeroAmaxKernel<<<1, 1, 0, stream>>>(reinterpret_cast<float*>(output_pre_rht_amax.dptr),
-                                      reinterpret_cast<float*>(output_identity_amax.dptr),
-                                      reinterpret_cast<float*>(output_transpose_amax.dptr));
-  NVTE_CHECK_CUDA(cudaGetLastError());
-
-  checkCuDriverContext(stream);
-
-  using IType = bf16;
-
   const size_t ndim = input.shape.size();
   const size_t row_length = input.shape[ndim - 1];
   size_t num_rows = 1;
@@ -785,6 +1093,41 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran
     num_rows *= input.shape[i];
   }
 
+  auto* pre_amax_ptr = reinterpret_cast<float*>(output_pre_rht_amax.dptr);
+  auto* id_amax_ptr = reinterpret_cast<float*>(output_identity_amax.dptr);
+  auto* tr_amax_ptr = reinterpret_cast<float*>(output_transpose_amax.dptr);
+
+#ifdef __HIP_PLATFORM_AMD__
+  NVTE_CHECK(row_length % kHadamardDim == 0, "row_length must be divisible by 16.");
+  NVTE_CHECK(num_rows % kHadamardDim == 0, "num_rows must be divisible by 16.");
+
+  auto* in_ptr = reinterpret_cast<const __hip_bfloat16*>(input.dptr);
+
+  if (pre_amax_ptr) {
+    NVTE_CHECK_CUDA(cudaMemsetAsync(pre_amax_ptr, 0, sizeof(float), stream));
+  }
+  if (id_amax_ptr) {
+    NVTE_CHECK_CUDA(cudaMemsetAsync(id_amax_ptr, 0, sizeof(float), stream));
+  }
+  if (tr_amax_ptr) {
+    NVTE_CHECK_CUDA(cudaMemsetAsync(tr_amax_ptr, 0, sizeof(float), stream));
+  }
+
+  if (return_pre_rht_amax) {
+    const uint64_t num_elems = static_cast<uint64_t>(num_rows) * row_length;
+    dim3 grid(DIVUP(num_elems, static_cast<uint64_t>(kThreadsPerBlock)));
+    PreRhtAmaxKernel<<<grid, kThreadsPerBlock, 0, stream>>>(in_ptr, pre_amax_ptr, num_elems);
+    NVTE_CHECK_CUDA(cudaGetLastError());
+  }
+#else
+  // Zero out amaxes if needed
+  ZeroAmaxKernel<<<1, 1, 0, stream>>>(pre_amax_ptr, id_amax_ptr, tr_amax_ptr);
+  NVTE_CHECK_CUDA(cudaGetLastError());
+
+  checkCuDriverContext(stream);
+
+  using IType = bf16;
+
   constexpr int kHadamardDimension = 16;
   NVTE_CHECK(row_length % kHadamardDimension == 0,
              "row_length must be divisible by hadamard_dimension.");
@@ -817,6 +1160,7 @@ void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t ran
   dim3 block(kThreadBlockX * kThreadsPerWarp, kThreadBlockY);
   dim3 grid(DIVUP(row_length, kChunkBlockXSmall), DIVUP(num_rows, kChunkBlockYSmall));
+#endif
 
   TRANSFORMER_ENGINE_SWITCH_CONDITION(
       return_transposed_amax, kReturnTransposedAmax,
@@ -824,6 +1168,17 @@
       TRANSFORMER_ENGINE_SWITCH_CONDITION(
           return_identity_amax, kReturnIdentityAmax,
+#ifdef __HIP_PLATFORM_AMD__
+          if (kReturnIdentityAmax || kReturnTransposedAmax) {
+            dim3 grid = transform_grid(num_rows, row_length), block(kThreadsPerBlock);
+            HadamardTransformKernel<false, false, kReturnIdentityAmax, kReturnTransposedAmax>
+                <<<grid, block, 0, stream>>>(in_ptr, nullptr, nullptr,
random_sign_mask, random_sign_mask_t, static_cast<uint64_t>(num_rows),
+                                             static_cast<uint64_t>(row_length),
+                                             id_amax_ptr, tr_amax_ptr, false);
+          }
+#else
           TRANSFORMER_ENGINE_SWITCH_CONDITION(
               return_pre_rht_amax, kReturnPreRhtAmax,
@@ -847,12 +1202,14 @@
                   reinterpret_cast<float*>(output_identity_amax.dptr),
                   reinterpret_cast<float*>(output_transpose_amax.dptr),
                   random_sign_mask, random_sign_mask_t, num_rows, row_length);)));
+#endif
+  ));
   NVTE_CHECK_CUDA(cudaGetLastError());
 #else
   NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ",
              CUDA_VERSION);
-#endif  // CUDA_VERSION >= 12080
+#endif  // CUDA_VERSION >= 12080 || __HIP_PLATFORM_AMD__
 }
 
 }  // namespace transformer_engine
diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h
index 6785040df..90a722e66 100644
--- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h
+++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h
@@ -13,8 +13,6 @@
 #ifndef TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_
 #define TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_
 
-#ifndef __HIP_PLATFORM_AMD__
-
 #include "transformer_engine.h"
 
 #ifdef __cplusplus
@@ -69,6 +67,4 @@ void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTE
 }  // extern "C"
 #endif
 
-#endif
-
 #endif  // TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_
diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h
index 27f24a961..6c19bae13 100644
--- a/transformer_engine/pytorch/csrc/common.h
+++ b/transformer_engine/pytorch/csrc/common.h
@@ -293,7 +293,6 @@ class MXFP8Quantizer : public Quantizer {
   std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
 };
 
-#ifndef USE_ROCM
 class NVFP4Quantizer : public Quantizer {
  public:
   // fp4 dtype
@@ -347,7 +346,6 @@ class NVFP4Quantizer : public Quantizer {
   void quantize_impl(const TensorWrapper& input, TensorWrapper& out,
                      const std::optional& noop_flag, bool compute_amax);
 };
-#endif  // #ifndef USE_ROCM
 
 std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);
diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h
index d5fd4a4fe..b924e8a77 100644
--- a/transformer_engine/pytorch/csrc/pybind.h
+++ b/transformer_engine/pytorch/csrc/pybind.h
@@ -110,12 +110,8 @@ constexpr std::array custom_types_converters = {
                     CreateQuantizer),
     std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers,
                     NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer),
-#ifdef USE_ROCM
-};
-#else
     std::make_tuple(IsNVFP4Tensor, IsNVFP4Quantizers, NVTETensorFromNVFP4Tensor,
                     CreateQuantizer)};
-#endif
 
 }  // namespace detail
 
 }  // namespace transformer_engine::pytorch
diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp
index 50c6bc810..7ca0ab3bf 100644
--- a/transformer_engine/pytorch/csrc/quantizer.cpp
+++ b/transformer_engine/pytorch/csrc/quantizer.cpp
@@ -1163,7 +1163,6 @@ std::vector<size_t> MXFP8Quantizer::get_scale_shape(const std::vector<size_t>& s
 #endif
 }
 
-#ifndef USE_ROCM
 NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantizer) {
   this->dtype = quantizer.attr("dtype").cast<DType>();
   this->with_rht = quantizer.attr("with_rht").cast<bool>();
@@ -1516,6 +1515,10 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
   bool eligible_for_rht_cast_fusion = input.dtype() ==
DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; +#ifdef USE_ROCM + eligible_for_rht_cast_fusion = false; +#endif + // Compute amax. if (this->with_rht) { if (input.dtype() != DType::kBFloat16) { @@ -1663,6 +1666,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), quant_config, stream); }); } else { +#ifndef USE_ROCM // RHT cast fusion kernel. NVTE_CHECK(this->rht_matrix.defined() && this->rht_matrix.numel() > 0, "RHT matrix is not set"); @@ -1671,6 +1675,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou nvte_hadamard_transform_cast_fusion_columnwise( input.data(), out_transpose.data(), rht_matrix_nvte.data(), quant_config, stream); }); +#endif } } } else { @@ -1740,6 +1745,5 @@ std::vector NVFP4Quantizer::get_scale_shape(const std::vector& s } return scale_shape; } -#endif } // namespace transformer_engine::pytorch