-
Notifications
You must be signed in to change notification settings - Fork 25
NVFP4 dequantization #505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
NVFP4 dequantization #505
Changes from all commits
b8a4024
0519b4b
8f4b04d
e60ff21
8bbb162
f573b40
eaaae94
8f94cf6
bac7993
8ae38e8
05a977a
0385852
46d382d
15416f1
bac5096
da24223
c03b7bb
c453dba
4a843ba
316dffb
8a47bc5
5c747bd
db56b8f
b318bda
62eea94
6d459ec
6eb2707
8cec975
c20e0e9
ccda439
4b0fd34
84934c2
e79134a
586bd09
4896edf
aa18e9a
c918a19
5bd7388
95d0c9f
6cd6038
55a8c84
10d88bf
b4caf6f
511db61
36cf73a
a85f68f
f4f5ec9
a607feb
ca2e444
5a5803c
d36ccbd
fc5af65
94a4e5e
82af544
56fefaf
dfd3205
a39e0d5
0b07970
645e37b
99fc99f
5f5dece
2b2ff5c
1f71218
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,235 @@ | ||||||
| /************************************************************************* | ||||||
| * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. | ||||||
| * | ||||||
| * License for AMD contributions = MIT. See LICENSE for more information | ||||||
| ************************************************************************/ | ||||||
|
|
||||||
| #include <cuda_bf16.h> | ||||||
| #include <cuda_fp8.h> | ||||||
| #include <cuda_fp4.h> | ||||||
| #include <cuda_runtime.h> | ||||||
| #include <gtest/gtest.h> | ||||||
|
|
||||||
| #include <cstring> | ||||||
| #include <memory> | ||||||
| #include <random> | ||||||
| #include <vector> | ||||||
|
|
||||||
| #include <transformer_engine/cast.h> | ||||||
| #include <transformer_engine/activation.h> | ||||||
| #include "../test_common.h" | ||||||
| #include "transformer_engine/transformer_engine.h" | ||||||
|
|
||||||
| using namespace transformer_engine; | ||||||
| using namespace test; | ||||||
|
|
||||||
| namespace { | ||||||
|
|
||||||
// Decode table for the 4-bit E2M1 (FP4) format: index = raw nibble value.
// Entries 0-7 are the positive magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6};
// entries 8-15 repeat the same magnitudes negated (bit 3 is the sign bit).
static constexpr float E2M1_LUT[16] = {
    0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
};
|
|
||||||
| // Generates random FP8 (E4M3) scale values by sampling raw 8-bit patterns. | ||||||
| // Each element is filled with a uniformly random byte [0–255], covering all | ||||||
| // possible FP8 encodings. Values are written using memcpy to preserve exact | ||||||
| // bit patterns rather than relying on numeric conversion. | ||||||
| void generate_scales(fp8e4m3* scales, | ||||||
| const size_t rows, | ||||||
| const size_t blocks_per_row, | ||||||
| const size_t scale_stride, | ||||||
| std::mt19937& gen, | ||||||
| std::uniform_int_distribution<int>& dis) { | ||||||
| for (size_t i = 0; i < rows; ++i) { | ||||||
| for (size_t j = 0; j < blocks_per_row; ++j) { | ||||||
| const size_t idx = i * scale_stride + j; | ||||||
| const uint8_t bits = static_cast<uint8_t>(dis(gen)); | ||||||
| std::memcpy(&scales[idx], &bits, sizeof(bits)); | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| // Populate FP4 (E2M1) tensor using packed 4-bit encoding. | ||||||
| // Two values are stored per byte (lo/hi nibbles). Each nibble is sampled | ||||||
| // uniformly from [0, 15] and packed into a single byte. Requires cols to be even. | ||||||
| void generate_data(fp4e2m1* data, | ||||||
| const size_t rows, | ||||||
| const size_t cols, | ||||||
| std::mt19937& gen, | ||||||
| std::uniform_int_distribution<int>& dis) { | ||||||
| ASSERT_EQ(cols % 2, 0u); | ||||||
|
|
||||||
| auto* raw = reinterpret_cast<uint8_t*>(data); | ||||||
|
|
||||||
| for (size_t i = 0; i < rows; ++i) { | ||||||
| for (size_t j = 0; j < cols; j += 2) { | ||||||
| const size_t idx_pair = (i * cols + j) / 2; | ||||||
| const uint8_t lo = static_cast<uint8_t>(dis(gen)) & 0xF; | ||||||
| const uint8_t hi = static_cast<uint8_t>(dis(gen)) & 0xF; | ||||||
| raw[idx_pair] = static_cast<uint8_t>(lo | (hi << 4)); | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| // Decode a single FP4 (E2M1) value from packed storage. | ||||||
| // Each byte contains two 4-bit values (nibbles). This extracts the appropriate | ||||||
| // nibble for the given logical index and converts it to float via a lookup table. | ||||||
| float get_fp4_value(const fp4e2m1* data, const size_t logical_idx) { | ||||||
| const auto* raw = reinterpret_cast<const uint8_t*>(data); | ||||||
| const size_t idx_pair = logical_idx / 2; | ||||||
| const uint8_t packed = raw[idx_pair]; | ||||||
| const uint8_t nibble = (logical_idx % 2 == 0) ? (packed & 0xF) : ((packed >> 4) & 0xF); | ||||||
| return E2M1_LUT[nibble]; | ||||||
| } | ||||||
|
|
||||||
| // Reference implementation: dequantize packed FP4 (E2M1) input using per-block FP8_E4M3 scales. | ||||||
| // Each block of 1x16 elements shares one scale; values are decoded to float and scaled, | ||||||
| // then written to output. | ||||||
| template <typename OutputType> | ||||||
| void compute_ref(const fp4e2m1* input, | ||||||
| OutputType* output, | ||||||
| const fp8e4m3* scales, | ||||||
| const float amax, | ||||||
| const size_t rows, | ||||||
| const size_t cols, | ||||||
| const size_t scale_stride) { | ||||||
| constexpr size_t block_size = 16; | ||||||
| constexpr float factor_inv = 1.0f / (6.0f * 448.0f); | ||||||
|
|
||||||
| const size_t blocks_per_row = cols / block_size; | ||||||
|
|
||||||
| for (size_t i = 0; i < rows; ++i) { | ||||||
| for (size_t b = 0; b < blocks_per_row; ++b) { | ||||||
| const float scale = | ||||||
| static_cast<float>(scales[i * scale_stride + b]) * amax * factor_inv; | ||||||
|
|
||||||
| for (size_t k = 0; k < block_size; ++k) { | ||||||
| const size_t col = b * block_size + k; | ||||||
| const size_t idx = i * cols + col; | ||||||
| const float x = get_fp4_value(input, idx); | ||||||
| output[idx] = static_cast<OutputType>(x * scale); | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| // End-to-end test: generate random FP4 input and FP8 scales, run device dequantization, | ||||||
| // compute reference on host, and compare results. | ||||||
| template <typename OutputType> | ||||||
| void performTest(const size_t rows, const size_t cols, DType otype) { | ||||||
| constexpr size_t block_size_1d = 16; | ||||||
| ASSERT_EQ(cols % block_size_1d, 0u); | ||||||
| ASSERT_EQ(cols % 2, 0u); | ||||||
|
|
||||||
| const DType itype = DType::kFloat4E2M1; | ||||||
| const size_t blocks_per_row = cols / block_size_1d; | ||||||
|
|
||||||
| Tensor input("input", std::vector<size_t>{rows, cols}, itype, | ||||||
| true, false, NVTE_NVFP4_1D_SCALING); | ||||||
| Tensor output("output", std::vector<size_t>{rows, cols}, otype, true, false); | ||||||
|
|
||||||
| const NVTEShape scale_shape = input.rowwise_scale_inv_shape(); | ||||||
| ASSERT_GE(scale_shape.ndim, 1u); | ||||||
|
|
||||||
| size_t scale_numel = 1; | ||||||
| for (size_t i = 0; i < scale_shape.ndim; ++i) { | ||||||
| scale_numel *= scale_shape.data[i]; | ||||||
| } | ||||||
| const size_t scale_stride = scale_shape.data[scale_shape.ndim - 1]; | ||||||
|
|
||||||
| const size_t data_bytes = (rows * cols * BitsNumber<fp4e2m1>::num_bits) / 8; | ||||||
| const size_t scale_bytes = scale_numel * sizeof(fp8e4m3); | ||||||
|
|
||||||
| std::unique_ptr<fp4e2m1[]> host_input = | ||||||
| std::make_unique<fp4e2m1[]>(rows * cols); | ||||||
| std::unique_ptr<fp8e4m3[]> host_scales = | ||||||
| std::make_unique<fp8e4m3[]>(scale_numel); | ||||||
| std::unique_ptr<OutputType[]> ref_output = | ||||||
| std::make_unique<OutputType[]>(rows * cols); | ||||||
|
|
||||||
| static std::mt19937 gen(42); | ||||||
| std::uniform_int_distribution<int> fp4_dis(0, 15); | ||||||
| std::uniform_int_distribution<int> fp8_dis(0, 255); | ||||||
|
|
||||||
| generate_data(host_input.get(), rows, cols, gen, fp4_dis); | ||||||
| generate_scales(host_scales.get(), | ||||||
|
Comment on lines
+154
to
+155
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. According to the layout alignment requirement, the data and scale for nvfp4 are not continuous in memory. Probably we can reuse the nvfp4 quantization here to generate a valid nvfp4 tensor |
||||||
| rows, | ||||||
| blocks_per_row, | ||||||
| scale_stride, | ||||||
| gen, | ||||||
| fp8_dis); | ||||||
|
|
||||||
| auto err = cudaMemcpy(input.rowwise_dptr(), host_input.get(), data_bytes, cudaMemcpyHostToDevice); | ||||||
| ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); | ||||||
|
|
||||||
| err = cudaMemcpy(input.rowwise_scale_inv_dptr(), host_scales.get(), scale_bytes, cudaMemcpyHostToDevice); | ||||||
| ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); | ||||||
|
|
||||||
| const float amax = 1.0f; | ||||||
| input.set_tensor_amax(amax); | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. set_scale() instead?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I think for dequantization, the scale is needed |
||||||
|
|
||||||
| // Perform NVFP4 dequantization with device kernel | ||||||
| nvte_dequantize(input.data(), output.data(), 0); | ||||||
|
|
||||||
| cudaDeviceSynchronize(); | ||||||
| err = cudaGetLastError(); | ||||||
| ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); | ||||||
|
|
||||||
| output.to_cpu(); | ||||||
|
|
||||||
| // Perform NVFP4 dequantization ref on the host | ||||||
| compute_ref(host_input.get(), | ||||||
| ref_output.get(), | ||||||
| host_scales.get(), | ||||||
| amax, | ||||||
| rows, | ||||||
| cols, | ||||||
| scale_stride); | ||||||
|
|
||||||
| auto [atol, rtol] = getTolerances(otype); | ||||||
| compareResults("output", output, ref_output.get(), true, atol, rtol); | ||||||
| } | ||||||
|
|
||||||
// Tensor shapes under test, as (rows, cols) pairs.
//
// TODO(review): like MXFP8, NVFP4 has its own scale_inv layout agreement for
// rowwise/colwise data (see tests/cpp/test_common.h). For {32, 32}, the
// rowwise scale-inv rows are not contiguous: with a 128-entry rowwise
// alignment, each row is padded from 32/16 = 2 scales up to 128.
std::vector<std::pair<size_t, size_t>> tensor_dims = {
    {32, 32},     {32, 64},     {64, 32},
    {64, 96},     {128, 128},   {256, 256},
    {512, 512},   {1024, 1024}, {2048, 2048},
};
|
|
||||||
| } // namespace | ||||||
|
|
||||||
// Parameterized test fixture: each case is a ((rows, cols), output dtype) pair.
class DequantizeNVFP4TestSuite
    : public ::testing::TestWithParam<
          std::tuple<std::pair<size_t, size_t>, transformer_engine::DType>> {};
|
|
||||||
| TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) { | ||||||
| const auto tensor_size = std::get<0>(GetParam()); | ||||||
| const DType output_type = std::get<1>(GetParam()); | ||||||
|
|
||||||
| const size_t rows = tensor_size.first; | ||||||
| const size_t cols = tensor_size.second; | ||||||
|
|
||||||
| TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY( | ||||||
| output_type, OutputType, | ||||||
| performTest<OutputType>(rows, cols, output_type);); | ||||||
| } | ||||||
|
|
||||||
| INSTANTIATE_TEST_SUITE_P( | ||||||
| OperatorTest, | ||||||
| DequantizeNVFP4TestSuite, | ||||||
| ::testing::Combine( | ||||||
| ::testing::ValuesIn(tensor_dims), | ||||||
| ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)), | ||||||
| [](const testing::TestParamInfo<DequantizeNVFP4TestSuite::ParamType>& info) { | ||||||
| std::string name = | ||||||
| std::to_string(std::get<0>(info.param).first) + "X" + | ||||||
| std::to_string(std::get<0>(info.param).second) + "X" + | ||||||
| test::typeName(std::get<1>(info.param)); | ||||||
| return name; | ||||||
| }); | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Try to also test with 2D scaling, and with columnwise data