Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
b8a4024
[ROCm] resolve the conflicts in common dir
wangye805 Feb 2, 2026
0519b4b
[ROCm] resolve the conflicts on jax side
wangye805 Feb 10, 2026
8f4b04d
[ROCm] resolve the conflicts on pytorch side
wangye805 Feb 10, 2026
e60ff21
[ROCm] resolve the conflicts in setup
wangye805 Feb 10, 2026
8bbb162
[ROCm] resolve the cpp gtest
wangye805 Feb 11, 2026
f573b40
[ROCm] resolve pytorch and jax tests
alextmagro Feb 11, 2026
eaaae94
pytest, example, wheels conflict resolution
alextmagro Feb 19, 2026
8f94cf6
jax and pytorch bugfix
alextmagro Feb 24, 2026
bac7993
copyrights and fp8_autocast->autocast fix
alextmagro Feb 24, 2026
8ae38e8
Enable test_distributed_dense.py
alextmagro Feb 24, 2026
05a977a
address IFU comments
alextmagro Mar 3, 2026
0385852
_FormatHelperFP8 and missing file add
alextmagro Mar 3, 2026
46d382d
add use_async_d2h_group_size as a test parameter
alextmagro Mar 3, 2026
15416f1
enable FP4 tests
matthiasdiener Mar 3, 2026
bac5096
rough initial version
matthiasdiener Mar 4, 2026
da24223
initial working version
matthiasdiener Mar 5, 2026
c03b7bb
Addressing comments and small fixes
alextmagro Mar 5, 2026
c453dba
various cleanups
matthiasdiener Mar 5, 2026
4a843ba
manually update runner labels
matthiasdiener Mar 5, 2026
316dffb
Comment cleanup
alextmagro Mar 5, 2026
8a47bc5
Merge remote-tracking branch 'origin/IFU-dev-20251114-v2.10' into mdi…
matthiasdiener Mar 5, 2026
5c747bd
only enable on gfx950
matthiasdiener Mar 5, 2026
db56b8f
Update jax gemm.py
alextmagro Mar 5, 2026
b318bda
Merge remote-tracking branch 'origin/IFU-dev-20251114-v2.10' into mdi…
matthiasdiener Mar 6, 2026
62eea94
Revert "only enable on gfx950"
matthiasdiener Mar 6, 2026
6d459ec
reenable in NVTEDType
matthiasdiener Mar 6, 2026
6eb2707
Fix dev merge conflicts
alextmagro Mar 6, 2026
8cec975
enable in bwd_helper
matthiasdiener Mar 6, 2026
c20e0e9
Merge remote-tracking branch 'origin/IFU-dev-20251114-v2.10' into mdi…
matthiasdiener Mar 9, 2026
ccda439
alignment fixes
matthiasdiener Mar 9, 2026
4b0fd34
fix merge error
matthiasdiener Mar 9, 2026
84934c2
minor fixes
matthiasdiener Mar 9, 2026
e79134a
Merge remote-tracking branch 'origin/dev' into mdiener/fp4-cast-trans…
matthiasdiener Mar 11, 2026
586bd09
Run CI
leo-amd Mar 12, 2026
4896edf
Merge branch 'dev' into mdiener/fp4-cast-transpose
matthiasdiener Mar 12, 2026
aa18e9a
more scales fixing
matthiasdiener Mar 13, 2026
c918a19
Merge remote-tracking branch 'origin/dev' into mdiener/fp4-cast-trans…
matthiasdiener Mar 13, 2026
5bd7388
Merge remote-tracking branch 'origin/dev' into mdiener/fp4-cast-trans…
matthiasdiener Mar 16, 2026
95d0c9f
address review comments
matthiasdiener Mar 17, 2026
6cd6038
adjust error message slightly
matthiasdiener Mar 17, 2026
55a8c84
simplify via hipify map
matthiasdiener Mar 17, 2026
10d88bf
adjust more error messages
matthiasdiener Mar 17, 2026
b4caf6f
change disabling of header includes
matthiasdiener Mar 18, 2026
511db61
address review comments
matthiasdiener Mar 18, 2026
36cf73a
implement SR
matthiasdiener Mar 18, 2026
a85f68f
simplify slightly
matthiasdiener Mar 18, 2026
f4f5ec9
Merge remote-tracking branch 'origin/dev' into mdiener/fp4-cast-trans…
matthiasdiener Mar 19, 2026
a607feb
address review comments
matthiasdiener Mar 19, 2026
ca2e444
bugfix arch SR support
matthiasdiener Mar 19, 2026
5a5803c
use scale constants
matthiasdiener Mar 20, 2026
d36ccbd
Merge remote-tracking branch 'origin/dev' into mdiener/fp4-cast-trans…
matthiasdiener Mar 20, 2026
fc5af65
simplify to use __hip_fp4x4_storage_t directly
matthiasdiener Mar 20, 2026
94a4e5e
simplify storage for bit fiddling
matthiasdiener Mar 20, 2026
82af544
allow null amax in fallback kernel
matthiasdiener Mar 20, 2026
56fefaf
minor cleanup
matthiasdiener Mar 20, 2026
dfd3205
Merge remote-tracking branch 'origin/dev' into mdiener/fp4-cast-trans…
matthiasdiener Mar 23, 2026
a39e0d5
Merge remote-tracking branch 'origin/dev' into mdiener/fp4-cast-trans…
matthiasdiener Mar 24, 2026
0b07970
enable nvfp::dequantize to be called on amd gpu
aris134 Mar 25, 2026
645e37b
Add NVFP4 dequant operator test
aris134 Mar 26, 2026
99fc99f
add EOL to test_dequantize_nvfp4.cu
aris134 Mar 26, 2026
5f5dece
simplify and add comments to the NVFP4 dequantization test
aris134 Mar 26, 2026
2b2ff5c
Merge branch 'dev' into amartin/nvfp4-dequant
aris134 Mar 26, 2026
1f71218
Update copy right string and replace hip prefixes
aris134 Mar 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/cpp/operator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ list(APPEND test_cuda_sources
test_qdq.cu
test_cast_mxfp8.cu
test_dequantize_mxfp8.cu
test_dequantize_nvfp4.cu
test_cast_nvfp4_transpose.cu
test_transpose.cu
test_cast_transpose.cu
Expand Down
235 changes: 235 additions & 0 deletions tests/cpp/operator/test_dequantize_nvfp4.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
/*************************************************************************
* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
*
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/

#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_fp4.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <cstring>
#include <memory>
#include <random>
#include <vector>

#include <transformer_engine/cast.h>
#include <transformer_engine/activation.h>
#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"

using namespace transformer_engine;
using namespace test;

namespace {

// Lookup table mapping a 4-bit FP4 (E2M1) code to its float value.
// Codes 0-7 are the non-negative representable values {0, 0.5, 1, 1.5, 2, 3, 4, 6};
// codes 8-15 mirror them with the sign bit (bit 3 of the nibble) set.
static constexpr float E2M1_LUT[16] = {
    0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
};

// Generates random FP8 (E4M3) scale values by sampling raw 8-bit patterns.
// Each element is filled with a uniformly random byte [0–255], covering all
// possible FP8 encodings. Values are written using memcpy to preserve exact
// bit patterns rather than relying on numeric conversion.
// Fill the scale tensor with uniformly random FP8 (E4M3) bit patterns.
// Each element receives one uniformly random byte [0, 255], so every possible
// FP8 encoding can occur.  memcpy is used so the exact bit pattern is stored,
// avoiding any numeric conversion through float.
// Only the first blocks_per_row entries of each scale_stride-wide row are
// written; any padding beyond that is left untouched.
void generate_scales(fp8e4m3* scales,
                     const size_t rows,
                     const size_t blocks_per_row,
                     const size_t scale_stride,
                     std::mt19937& gen,
                     std::uniform_int_distribution<int>& dis) {
  for (size_t row = 0; row < rows; ++row) {
    fp8e4m3* row_base = scales + row * scale_stride;
    for (size_t block = 0; block < blocks_per_row; ++block) {
      const uint8_t pattern = static_cast<uint8_t>(dis(gen));
      std::memcpy(&row_base[block], &pattern, sizeof(pattern));
    }
  }
}

// Populate FP4 (E2M1) tensor using packed 4-bit encoding.
// Two values are stored per byte (lo/hi nibbles). Each nibble is sampled
// uniformly from [0, 15] and packed into a single byte. Requires cols to be even.
// Populate an FP4 (E2M1) tensor with random packed 4-bit codes.
// Storage is two values per byte (low nibble = even column, high nibble = odd
// column).  Each nibble is drawn uniformly from [0, 15].  cols must be even so
// that every byte holds a complete pair.
void generate_data(fp4e2m1* data,
                   const size_t rows,
                   const size_t cols,
                   std::mt19937& gen,
                   std::uniform_int_distribution<int>& dis) {
  ASSERT_EQ(cols % 2, 0u);

  auto* bytes = reinterpret_cast<uint8_t*>(data);

  // Rows are contiguous, so the packed bytes form one flat array of
  // rows * cols / 2 value pairs; fill them in order.
  const size_t num_pairs = (rows * cols) / 2;
  for (size_t pair = 0; pair < num_pairs; ++pair) {
    const uint8_t lo_nibble = static_cast<uint8_t>(dis(gen)) & 0xF;
    const uint8_t hi_nibble = static_cast<uint8_t>(dis(gen)) & 0xF;
    bytes[pair] = static_cast<uint8_t>(lo_nibble | (hi_nibble << 4));
  }
}

// Decode a single FP4 (E2M1) value from packed storage.
// Each byte contains two 4-bit values (nibbles). This extracts the appropriate
// nibble for the given logical index and converts it to float via a lookup table.
// Decode one FP4 (E2M1) value from packed storage.
// Bytes hold two codes each: the low nibble is the even logical index, the
// high nibble the odd one.  The extracted 4-bit code is mapped to float via
// the E2M1 lookup table.
float get_fp4_value(const fp4e2m1* data, const size_t logical_idx) {
  const uint8_t byte = reinterpret_cast<const uint8_t*>(data)[logical_idx / 2];
  const bool use_high_nibble = (logical_idx & 1) != 0;
  const uint8_t code = use_high_nibble ? static_cast<uint8_t>((byte >> 4) & 0xF)
                                       : static_cast<uint8_t>(byte & 0xF);
  return E2M1_LUT[code];
}

// Reference implementation: dequantize packed FP4 (E2M1) input using per-block FP8_E4M3 scales.
// Each block of 1x16 elements shares one scale; values are decoded to float and scaled,
// then written to output.
// Host reference for NVFP4 dequantization with per-block FP8 (E4M3) scales.
// Every 1x16 block of input values shares one scale entry; the effective
// scale is scale * amax / (6 * 448), i.e. normalized by the maximum FP4
// magnitude (6) and the maximum E4M3 magnitude (448).  Decoded values are
// multiplied by the block scale and cast to the output type.
template <typename OutputType>
void compute_ref(const fp4e2m1* input,
                 OutputType* output,
                 const fp8e4m3* scales,
                 const float amax,
                 const size_t rows,
                 const size_t cols,
                 const size_t scale_stride) {
  constexpr size_t block_size = 16;
  constexpr float factor_inv = 1.0f / (6.0f * 448.0f);

  const size_t blocks_per_row = cols / block_size;

  for (size_t row = 0; row < rows; ++row) {
    const fp8e4m3* scale_row = scales + row * scale_stride;
    for (size_t block = 0; block < blocks_per_row; ++block) {
      const float block_scale = static_cast<float>(scale_row[block]) * amax * factor_inv;
      const size_t first_col = block * block_size;
      for (size_t offset = 0; offset < block_size; ++offset) {
        const size_t flat_idx = row * cols + first_col + offset;
        const float decoded = get_fp4_value(input, flat_idx);
        output[flat_idx] = static_cast<OutputType>(decoded * block_scale);
      }
    }
  }
}

// End-to-end test: generate random FP4 input and FP8 scales, run device dequantization,
// compute reference on host, and compare results.
// End-to-end test: generate random packed FP4 input and random FP8 scales,
// run the device dequantization kernel via nvte_dequantize, compute a host
// reference with compute_ref, and compare within per-dtype tolerances.
//
// Preconditions: cols must be a multiple of the 1D scaling block size (16)
// and even (two FP4 values are packed per byte).
template <typename OutputType>
void performTest(const size_t rows, const size_t cols, DType otype) {
  constexpr size_t block_size_1d = 16;
  ASSERT_EQ(cols % block_size_1d, 0u);
  ASSERT_EQ(cols % 2, 0u);

  const DType itype = DType::kFloat4E2M1;
  const size_t blocks_per_row = cols / block_size_1d;

  // Rowwise NVFP4 input with 1D (1x16) block scaling; dense output tensor.
  Tensor input("input", std::vector<size_t>{rows, cols}, itype,
               true, false, NVTE_NVFP4_1D_SCALING);
  Tensor output("output", std::vector<size_t>{rows, cols}, otype, true, false);

  const NVTEShape scale_shape = input.rowwise_scale_inv_shape();
  ASSERT_GE(scale_shape.ndim, 1u);

  size_t scale_numel = 1;
  for (size_t i = 0; i < scale_shape.ndim; ++i) {
    scale_numel *= scale_shape.data[i];
  }
  // The innermost dimension of the scale-inv shape is the per-row stride
  // (it may be padded beyond blocks_per_row by the layout alignment rules).
  const size_t scale_stride = scale_shape.data[scale_shape.ndim - 1];

  const size_t data_bytes = (rows * cols * BitsNumber<fp4e2m1>::num_bits) / 8;
  const size_t scale_bytes = scale_numel * sizeof(fp8e4m3);

  std::unique_ptr<fp4e2m1[]> host_input =
      std::make_unique<fp4e2m1[]>(rows * cols);
  std::unique_ptr<fp8e4m3[]> host_scales =
      std::make_unique<fp8e4m3[]>(scale_numel);
  std::unique_ptr<OutputType[]> ref_output =
      std::make_unique<OutputType[]>(rows * cols);

  // Fixed seed so failures are reproducible; static so successive test
  // instantiations draw different values from the same stream.
  static std::mt19937 gen(42);
  std::uniform_int_distribution<int> fp4_dis(0, 15);
  std::uniform_int_distribution<int> fp8_dis(0, 255);

  generate_data(host_input.get(), rows, cols, gen, fp4_dis);
  generate_scales(host_scales.get(),
                  rows,
                  blocks_per_row,
                  scale_stride,
                  gen,
                  fp8_dis);

  auto err = cudaMemcpy(input.rowwise_dptr(), host_input.get(), data_bytes, cudaMemcpyHostToDevice);
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  err = cudaMemcpy(input.rowwise_scale_inv_dptr(), host_scales.get(), scale_bytes, cudaMemcpyHostToDevice);
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  // The dequantization path reads the tensor amax as part of the scale
  // decode (see factor amax / (6 * 448) in compute_ref).
  const float amax = 1.0f;
  input.set_tensor_amax(amax);

  // Perform NVFP4 dequantization with device kernel on the default stream.
  nvte_dequantize(input.data(), output.data(), 0);

  // Check both the synchronization result and any pending launch error.
  err = cudaDeviceSynchronize();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
  err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  output.to_cpu();

  // Perform NVFP4 dequantization ref on the host.
  compute_ref(host_input.get(),
              ref_output.get(),
              host_scales.get(),
              amax,
              rows,
              cols,
              scale_stride);

  auto [atol, rtol] = getTolerances(otype);
  compareResults("output", output, ref_output.get(), true, atol, rtol);
}

// Tensor shapes under test.  Every column count is a multiple of the
// 16-element scaling block (and therefore even, as required by the packed
// FP4 storage).
std::vector<std::pair<size_t, size_t>> tensor_dims = {
    {32, 32},
    {32, 64},
    {64, 32},
    {64, 96},
    {128, 128},
    {256, 256},
    {512, 512},
    {1024, 1024},
    {2048, 2048},
};

} // namespace

// Value-parameterized suite over (tensor shape, output dtype) combinations.
class DequantizeNVFP4TestSuite
    : public ::testing::TestWithParam<
          std::tuple<std::pair<size_t, size_t>, transformer_engine::DType>> {};

// Runs one dequantization test for the current (shape, output dtype)
// parameter tuple, dispatching to the concrete C++ output type via the
// Transformer Engine fp16/bf16/fp32 type switch.
TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) {
  const auto tensor_size = std::get<0>(GetParam());
  const DType output_type = std::get<1>(GetParam());

  const size_t rows = tensor_size.first;
  const size_t cols = tensor_size.second;

  TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(
      output_type, OutputType,
      performTest<OutputType>(rows, cols, output_type););
}

// Instantiate the suite over the full cross product of tensor_dims and
// output dtypes, naming each case "<rows>X<cols>X<dtype>".
INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    DequantizeNVFP4TestSuite,
    ::testing::Combine(
        ::testing::ValuesIn(tensor_dims),
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)),
    [](const testing::TestParamInfo<DequantizeNVFP4TestSuite::ParamType>& info) {
      std::string name =
          std::to_string(std::get<0>(info.param).first) + "X" +
          std::to_string(std::get<0>(info.param).second) + "X" +
          test::typeName(std::get<1>(info.param));
      return name;
    });
17 changes: 16 additions & 1 deletion tests/cpp/test_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,21 @@ class Tensor {
tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape);
}

// Sets the tensor's amax (absolute maximum) value on both host and device:
// caches it in amax_cpu_data_, copies it into a freshly allocated device
// float, and registers that device buffer with the underlying NVTETensor.
// NOTE(review): a new device buffer is cudaMalloc'd on every call and never
// freed — repeated calls leak device memory and orphan the previously
// registered pointer; consider caching the allocation as a member — TODO
// confirm whether tensor_.set_amax takes ownership.
void set_tensor_amax(float amax) {
  if (!amax_cpu_data_) {
    amax_cpu_data_ = std::make_shared<float>(amax);
  } else {
    *amax_cpu_data_ = amax;
  }

  float *amax_gpu = nullptr;
  NVTE_CHECK_CUDA(cudaMalloc(&amax_gpu, sizeof(float)));
  NVTE_CHECK_CUDA(cudaMemcpy(amax_gpu, amax_cpu_data_.get(),
                             sizeof(float), cudaMemcpyHostToDevice));

  tensor_.set_amax(amax_gpu, DType::kFloat32, tensor_.defaultShape);
}

void to_cpu() const;
void from_cpu() const;
void set_scale(float scale);
Expand Down Expand Up @@ -519,7 +534,7 @@ template <typename T>
void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
const size_t row_blocks, const size_t col_blocks, const size_t stride,
#ifdef USE_ROCM
std::vector<size_t>& mismatch_indices,
std::vector<size_t>& mismatch_indices,
#endif //#ifdef USE_ROCM
size_t& mismatches_num,
const size_t scale_diff_abs_tolerance = 0,
Expand Down
4 changes: 0 additions & 4 deletions transformer_engine/common/cast/dispatch/dequantize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
#include "../../common.h"
#include "../fp8/dequantize_fp8.cuh"
#include "../mxfp8/dequantize_mxfp8.cuh"
#ifndef __HIP_PLATFORM_AMD__
#include "../nvfp4/dequantize_nvfp4.cuh"
#endif //#ifndef __HIP_PLATFORM_AMD__

namespace transformer_engine {
namespace dispatch {
Expand Down Expand Up @@ -49,12 +47,10 @@ inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t
#endif //#ifndef __HIP_PLATFORM_AMD__
break;
}
#ifndef __HIP_PLATFORM_AMD__
case NVTE_NVFP4_1D_SCALING: {
nvfp4::dequantize(input, output, stream);
break;
}
#endif //#ifndef __HIP_PLATFORM_AMD__
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,25 @@ static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kT
static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");

// For 2D block scaling, amax must be reduced across a group of lanes within
// a warp.  Mask i has a bit set every 8 lanes starting at lane i, selecting
// the lanes that belong to reduction group i.
// The AMD build uses 64-bit masks (one bit per lane of a 64-wide wavefront);
// the CUDA build uses 32-bit masks for 32-lane warps.
#ifdef __HIP_PLATFORM_AMD__
static __device__ constexpr uint64_t WARP_REDUCE_AMAX_GROUP_MASKS[8] = {
    0x0101010101010101ULL, 0x0202020202020202ULL,
    0x0404040404040404ULL, 0x0808080808080808ULL,
    0x1010101010101010ULL, 0x2020202020202020ULL,
    0x4040404040404040ULL, 0x8080808080808080ULL};
#else
static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = {
    0x01010101, 0x02020202, 0x04040404, 0x08080808, 0x10101010, 0x20202020, 0x40404040, 0x80808080};
#endif

// max for every group_size elements in warp
template <int group_size, int shfl_down_stride>
__device__ __forceinline__ float groupMax(float val, unsigned int groupMask) {
__device__ __forceinline__ float groupMax(float val,
#ifdef __HIP_PLATFORM_AMD__
uint64_t groupMask) {
#else
unsigned int groupMask) {
#endif
for (int offset = group_size / 2; offset > 0; offset /= 2) {
#ifdef __HIP_PLATFORM_AMD__
(void)groupMask; // unused on AMD, __shfl_down does not take a mask
Expand Down