Conversation
Resolve wheels and examples
…ener/fp4-cast-transpose
…ener/fp4-cast-transpose
This reverts commit 5c747bd.
…ener/fp4-cast-transpose
| ASSERT_EQ(err, hipSuccess) << hipGetErrorString(err); | ||
|
|
||
| const float amax = 1.0f; | ||
| input.set_tensor_amax(amax); |
There was a problem hiding this comment.
set_scale() instead?
There was a problem hiding this comment.
Yeah, I think for dequantization, the scale is needed
ipanfilo
left a comment
There was a problem hiding this comment.
It is based on PR#472. Not to review the same changes twice let's wait for that PR to merge
| } | ||
|
|
||
| std::vector<std::pair<size_t, size_t>> tensor_dims = { | ||
| {32, 32}, |
There was a problem hiding this comment.
Like mxfp8, NV fp4 has its own scale_inv layout agreement for rowwise/colwise data:
TransformerEngine/tests/cpp/test_common.h
Line 348 in 98ccd2e
Take tensor dim {32,32} as an example, the rowwise scale inv will not be a continuous array for the first and the second row because nvfp4_scale_tensor_alignment_Y_rowwise=128, so padding is needed from 32/16=2 to 128 per row
| ASSERT_EQ(err, hipSuccess) << hipGetErrorString(err); | ||
|
|
||
| const float amax = 1.0f; | ||
| input.set_tensor_amax(amax); |
There was a problem hiding this comment.
Yeah, I think for dequantization, the scale is needed
| generate_data(host_input.get(), rows, cols, gen, fp4_dis); | ||
| generate_scales(host_scales.get(), |
There was a problem hiding this comment.
According to the layout alignment requirement, the data and scale for nvfp4 are not continuous in memory. Probably we can reuse the nvfp4 quantization here to generate a valid nvfp4 tensor
| const size_t blocks_per_row = cols / block_size_1d; | ||
|
|
||
| Tensor input("input", std::vector<size_t>{rows, cols}, itype, | ||
| true, false, NVTE_NVFP4_1D_SCALING); |
There was a problem hiding this comment.
Try to also test with 2D scaling, and with columnwise data
Description
Fixes https://github.com/ROCm/frameworks-internal/issues/15998
Enable NVFP4 dequantization on AMD GPU (gfx950) and add unit test.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: