
NVFP4 dequantization #505

Open

aris134 wants to merge 63 commits into dev from amartin/nvfp4-dequant

Conversation

@aris134 commented Mar 25, 2026

Description

Fixes https://github.com/ROCm/frameworks-internal/issues/15998

Enable NVFP4 dequantization on AMD GPUs (gfx950) and add a unit test.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Enable compilation of NVFP4 dequantization kernel for AMD GPU
  • Add unit test that verifies NVFP4 dequantization works on gfx950

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@aris134 self-assigned this Mar 26, 2026
@aris134 marked this pull request as ready for review March 26, 2026 13:16
ASSERT_EQ(err, hipSuccess) << hipGetErrorString(err);

const float amax = 1.0f;
input.set_tensor_amax(amax);
Contributor:

set_scale() instead?

Collaborator:

Yeah, I think for dequantization, the scale is needed

Collaborator @ipanfilo left a comment:

This is based on PR #472. To avoid reviewing the same changes twice, let's wait for that PR to merge first.

}

std::vector<std::pair<size_t, size_t>> tensor_dims = {
{32, 32},
Collaborator:

Like mxfp8, NVFP4 has its own scale_inv layout agreement for rowwise/colwise data:

constexpr size_t nvfp4_scale_tensor_alignment_Y_rowwise = 128;
constexpr size_t scale_tensor_alignment_X_rowwise = 4;

Take tensor dims {32, 32} as an example: the rowwise scale_inv will not be a contiguous array for the first and second rows, because nvfp4_scale_tensor_alignment_Y_rowwise = 128, so padding is needed from 32/16 = 2 to 128 per row.

Comment on lines +154 to +155
generate_data(host_input.get(), rows, cols, gen, fp4_dis);
generate_scales(host_scales.get(),
Collaborator:

According to the layout alignment requirement, the data and scales for NVFP4 are not contiguous in memory. Probably we can reuse the NVFP4 quantization here to generate a valid NVFP4 tensor.

const size_t blocks_per_row = cols / block_size_1d;

Tensor input("input", std::vector<size_t>{rows, cols}, itype,
true, false, NVTE_NVFP4_1D_SCALING);
Collaborator:

Try to also test with 2D scaling and with columnwise data.


6 participants