diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index dfd8fba29..8a19e84f5 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -14,6 +14,7 @@ list(APPEND test_cuda_sources test_qdq.cu test_cast_mxfp8.cu test_dequantize_mxfp8.cu + test_dequantize_nvfp4.cu test_cast_nvfp4_transpose.cu test_transpose.cu test_cast_transpose.cu diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu new file mode 100644 index 000000000..6986e1333 --- /dev/null +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -0,0 +1,235 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +static constexpr float E2M1_LUT[16] = { + 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, + -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f, +}; + +// Generates random FP8 (E4M3) scale values by sampling raw 8-bit patterns. +// Each element is filled with a uniformly random byte [0–255], covering all +// possible FP8 encodings. Values are written using memcpy to preserve exact +// bit patterns rather than relying on numeric conversion. 
+
void generate_scales(fp8e4m3* scales,
                     const size_t rows,
                     const size_t blocks_per_row,
                     const size_t scale_stride,
                     std::mt19937& gen,
                     std::uniform_int_distribution<int>& dis) {
  for (size_t i = 0; i < rows; ++i) {
    for (size_t j = 0; j < blocks_per_row; ++j) {
      const size_t idx = i * scale_stride + j;
      const uint8_t bits = static_cast<uint8_t>(dis(gen));
      std::memcpy(&scales[idx], &bits, sizeof(bits));
    }
  }
}

// Populate FP4 (E2M1) tensor using packed 4-bit encoding.
// Two values are stored per byte (lo/hi nibbles). Each nibble is sampled
// uniformly from [0, 15] and packed into a single byte. Requires cols to be even.
void generate_data(fp4e2m1* data,
                   const size_t rows,
                   const size_t cols,
                   std::mt19937& gen,
                   std::uniform_int_distribution<int>& dis) {
  ASSERT_EQ(cols % 2, 0u);

  auto* raw = reinterpret_cast<uint8_t*>(data);

  for (size_t i = 0; i < rows; ++i) {
    for (size_t j = 0; j < cols; j += 2) {
      const size_t idx_pair = (i * cols + j) / 2;
      const uint8_t lo = static_cast<uint8_t>(dis(gen)) & 0xF;
      const uint8_t hi = static_cast<uint8_t>(dis(gen)) & 0xF;
      raw[idx_pair] = static_cast<uint8_t>(lo | (hi << 4));
    }
  }
}

// Decode a single FP4 (E2M1) value from packed storage.
// Each byte contains two 4-bit values (nibbles). This extracts the appropriate
// nibble for the given logical index and converts it to float via a lookup table.
float get_fp4_value(const fp4e2m1* data, const size_t logical_idx) {
  const auto* raw = reinterpret_cast<const uint8_t*>(data);
  const size_t idx_pair = logical_idx / 2;
  const uint8_t packed = raw[idx_pair];
  const uint8_t nibble = (logical_idx % 2 == 0) ? (packed & 0xF) : ((packed >> 4) & 0xF);
  return E2M1_LUT[nibble];
}

// Reference implementation: dequantize packed FP4 (E2M1) input using per-block FP8_E4M3 scales.
// Each block of 1x16 elements shares one scale; values are decoded to float and scaled,
// then written to output.
template <typename OutputType>
void compute_ref(const fp4e2m1* input,
                 OutputType* output,
                 const fp8e4m3* scales,
                 const float amax,
                 const size_t rows,
                 const size_t cols,
                 const size_t scale_stride) {
  constexpr size_t block_size = 16;
  constexpr float factor_inv = 1.0f / (6.0f * 448.0f);

  const size_t blocks_per_row = cols / block_size;

  for (size_t i = 0; i < rows; ++i) {
    for (size_t b = 0; b < blocks_per_row; ++b) {
      const float scale =
          static_cast<float>(scales[i * scale_stride + b]) * amax * factor_inv;

      for (size_t k = 0; k < block_size; ++k) {
        const size_t col = b * block_size + k;
        const size_t idx = i * cols + col;
        const float x = get_fp4_value(input, idx);
        output[idx] = static_cast<OutputType>(x * scale);
      }
    }
  }
}

// End-to-end test: generate random FP4 input and FP8 scales, run device dequantization,
// compute reference on host, and compare results.
template <typename OutputType>
void performTest(const size_t rows, const size_t cols, DType otype) {
  constexpr size_t block_size_1d = 16;
  ASSERT_EQ(cols % block_size_1d, 0u);
  ASSERT_EQ(cols % 2, 0u);

  const DType itype = DType::kFloat4E2M1;
  const size_t blocks_per_row = cols / block_size_1d;

  Tensor input("input", std::vector<size_t>{rows, cols}, itype,
               true, false, NVTE_NVFP4_1D_SCALING);
  Tensor output("output", std::vector<size_t>{rows, cols}, otype, true, false);

  const NVTEShape scale_shape = input.rowwise_scale_inv_shape();
  ASSERT_GE(scale_shape.ndim, 1u);

  size_t scale_numel = 1;
  for (size_t i = 0; i < scale_shape.ndim; ++i) {
    scale_numel *= scale_shape.data[i];
  }
  const size_t scale_stride = scale_shape.data[scale_shape.ndim - 1];

  const size_t data_bytes = (rows * cols * BitsNumber<fp4e2m1>::num_bits) / 8;
  const size_t scale_bytes = scale_numel * sizeof(fp8e4m3);

  std::unique_ptr<fp4e2m1[]> host_input =
      std::make_unique<fp4e2m1[]>(rows * cols);
  std::unique_ptr<fp8e4m3[]> host_scales =
      std::make_unique<fp8e4m3[]>(scale_numel);
  std::unique_ptr<OutputType[]> ref_output =
      std::make_unique<OutputType[]>(rows * cols);

  static std::mt19937 gen(42);
  
std::uniform_int_distribution<int> fp4_dis(0, 15);
  std::uniform_int_distribution<int> fp8_dis(0, 255);

  generate_data(host_input.get(), rows, cols, gen, fp4_dis);
  generate_scales(host_scales.get(),
                  rows,
                  blocks_per_row,
                  scale_stride,
                  gen,
                  fp8_dis);

  auto err = cudaMemcpy(input.rowwise_dptr(), host_input.get(), data_bytes, cudaMemcpyHostToDevice);
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  err = cudaMemcpy(input.rowwise_scale_inv_dptr(), host_scales.get(), scale_bytes, cudaMemcpyHostToDevice);
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  const float amax = 1.0f;
  input.set_tensor_amax(amax);

  // Perform NVFP4 dequantization with device kernel
  nvte_dequantize(input.data(), output.data(), 0);

  cudaDeviceSynchronize();
  err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  output.to_cpu();

  // Perform NVFP4 dequantization ref on the host
  compute_ref(host_input.get(),
              ref_output.get(),
              host_scales.get(),
              amax,
              rows,
              cols,
              scale_stride);

  auto [atol, rtol] = getTolerances(otype);
  compareResults("output", output, ref_output.get(), true, atol, rtol);
}

std::vector<std::pair<size_t, size_t>> tensor_dims = {
    {32, 32},
    {32, 64},
    {64, 32},
    {64, 96},
    {128, 128},
    {256, 256},
    {512, 512},
    {1024, 1024},
    {2048, 2048},
};

}  // namespace

class DequantizeNVFP4TestSuite
    : public ::testing::TestWithParam<
          std::tuple<std::pair<size_t, size_t>, transformer_engine::DType>> {};

TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) {
  const auto tensor_size = std::get<0>(GetParam());
  const DType output_type = std::get<1>(GetParam());

  const size_t rows = tensor_size.first;
  const size_t cols = tensor_size.second;

  TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(
      output_type, OutputType,
      performTest<OutputType>(rows, cols, output_type););
}

INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    DequantizeNVFP4TestSuite,
    ::testing::Combine(
        ::testing::ValuesIn(tensor_dims),
        
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)),
    [](const testing::TestParamInfo<DequantizeNVFP4TestSuite::ParamType>& info) {
      std::string name =
          std::to_string(std::get<0>(info.param).first) + "X" +
          std::to_string(std::get<0>(info.param).second) + "X" +
          test::typeName(std::get<1>(info.param));
      return name;
    });
diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h
index e2bfdfd57..3028acdbe 100644
--- a/tests/cpp/test_common.h
+++ b/tests/cpp/test_common.h
@@ -301,6 +301,21 @@ class Tensor {
     tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape);
   }
 
+  void set_tensor_amax(float amax) {
+    if (!amax_cpu_data_) {
+      amax_cpu_data_ = std::make_shared<float>(amax);
+    } else {
+      *amax_cpu_data_ = amax;
+    }
+
+    float *amax_gpu = nullptr;
+    NVTE_CHECK_CUDA(cudaMalloc(&amax_gpu, sizeof(float)));
+    NVTE_CHECK_CUDA(cudaMemcpy(amax_gpu, amax_cpu_data_.get(),
+                    sizeof(float), cudaMemcpyHostToDevice));
+
+    tensor_.set_amax(amax_gpu, DType::kFloat32, tensor_.defaultShape);
+  }
+
   void to_cpu() const;
   void from_cpu() const;
   void set_scale(float scale);
@@ -519,7 +534,7 @@ template <typename T>
 void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
                              const size_t row_blocks, const size_t col_blocks, const size_t stride,
 #ifdef USE_ROCM
-                             std::vector& mismatch_indices,
+                             std::vector& mismatch_indices,
 #endif //#ifdef USE_ROCM
                              size_t& mismatches_num, const size_t scale_diff_abs_tolerance = 0,
diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh
index 1bdd7e218..6b70c7582 100644
--- a/transformer_engine/common/cast/dispatch/dequantize.cuh
+++ b/transformer_engine/common/cast/dispatch/dequantize.cuh
@@ -18,9 +18,7 @@
 #include "../../common.h"
 #include "../fp8/dequantize_fp8.cuh"
 #include "../mxfp8/dequantize_mxfp8.cuh"
-#ifndef __HIP_PLATFORM_AMD__
 #include "../nvfp4/dequantize_nvfp4.cuh"
-#endif //#ifndef __HIP_PLATFORM_AMD__
 
 namespace transformer_engine {
 namespace dispatch {
@@ -49,12 
+47,10 @@ inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t
 #endif //#ifndef __HIP_PLATFORM_AMD__
       break;
     }
-#ifndef __HIP_PLATFORM_AMD__
     case NVTE_NVFP4_1D_SCALING: {
       nvfp4::dequantize(input, output, stream);
       break;
     }
-#endif //#ifndef __HIP_PLATFORM_AMD__
     default:
       NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
   }
diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu
index 59742d1e7..bdcfd8d0b 100644
--- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu
+++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu
@@ -157,12 +157,25 @@ static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kT
 static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");
 
 // for 2D block scaling, we need to reduce amax in warp
+#ifdef __HIP_PLATFORM_AMD__
+static __device__ constexpr uint64_t WARP_REDUCE_AMAX_GROUP_MASKS[8] = {
+    0x0101010101010101ULL, 0x0202020202020202ULL,
+    0x0404040404040404ULL, 0x0808080808080808ULL,
+    0x1010101010101010ULL, 0x2020202020202020ULL,
+    0x4040404040404040ULL, 0x8080808080808080ULL};
+#else
 static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = {
     0x01010101, 0x02020202, 0x04040404, 0x08080808, 0x10101010, 0x20202020, 0x40404040, 0x80808080};
+#endif
 
 // max for every group_size elements in warp
 template <int group_size>
-__device__ __forceinline__ float groupMax(float val, unsigned int groupMask) {
+__device__ __forceinline__ float groupMax(float val,
+#ifdef __HIP_PLATFORM_AMD__
+                                          uint64_t groupMask) {
+#else
+                                          unsigned int groupMask) {
+#endif
   for (int offset = group_size / 2; offset > 0; offset /= 2) {
 #ifdef __HIP_PLATFORM_AMD__
     (void)groupMask;  // unused on AMD, __shfl_down does not take a mask