From 869147e953a1f1b13ebaa9df67f8082e2d219b94 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 4 Dec 2025 13:34:15 -0800
Subject: [PATCH] Update

[ghstack-poisoned]
---
 test/prototype/mx_formats/test_mx_tensor.py    |  1 +
 test/prototype/mx_formats/test_nvfp4_tensor.py | 18 ++++++++++++++++--
 torchao/prototype/mx_formats/kernels.py        |  5 +++--
 torchao/prototype/mx_formats/mx_tensor.py      |  4 +++-
 torchao/prototype/mx_formats/nvfp4_tensor.py   |  7 +++----
 5 files changed, 26 insertions(+), 9 deletions(-)

diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py
index 2b8c72ff91..f99488b2f1 100644
--- a/test/prototype/mx_formats/test_mx_tensor.py
+++ b/test/prototype/mx_formats/test_mx_tensor.py
@@ -76,6 +76,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
     prev_dims, K = data_hp.shape[:-1], data_hp.shape[-1]
     if elem_dtype is torch.float4_e2m1fn_x2:
         assert data_mx.qdata.shape == (*prev_dims, K // 2)
+        assert data_mx.qdata.dtype == torch.float4_e2m1fn_x2
     else:
         assert data_mx.qdata.shape == (*prev_dims, K)
     assert data_mx.scale.shape == (*prev_dims, K // block_size)
diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py
index 2f734cef2c..45bb484e9f 100644
--- a/test/prototype/mx_formats/test_nvfp4_tensor.py
+++ b/test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -392,8 +392,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
     )
     torch.testing.assert_close(nvfp4_pt.scale.flatten(), nvfp4_triton.scale.flatten())

-    pt_unpacked = unpack_uint4(nvfp4_pt.qdata)
-    triton_unpacked = unpack_uint4(nvfp4_triton.qdata)
+    pt_unpacked = unpack_uint4(nvfp4_pt.qdata.view(torch.uint8))
+    triton_unpacked = unpack_uint4(nvfp4_triton.qdata.view(torch.uint8))
     torch.testing.assert_close(
         pt_unpacked,
         triton_unpacked,
@@ -611,3 +611,17 @@ def test_3d_transpose(dims, is_swizzled_scales):
     x_hp_t = x_hp.transpose(dims[0], dims[1])
     x_nvfp4_t = x_nvfp4.transpose(dims[0], dims[1])
     assert x_hp_t.shape == x_nvfp4_t.shape
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
+)
+@pytest.mark.parametrize("use_triton_kernel", [False, True])
+def test_uses_fp4_qdata(use_triton_kernel):
+    x_hp = torch.randn(2, 128, 256, device="cuda")
+    # use_triton_kernel covers both the reference and triton quantize paths
+    x_nvfp4 = NVFP4Tensor.to_nvfp4(
+        x_hp, use_triton_kernel=use_triton_kernel, is_swizzled_scales=True
+    )
+    assert x_nvfp4.qdata.dtype == torch.float4_e2m1fn_x2
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index b4cd192244..6758dddbca 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -1028,8 +1028,9 @@ def triton_quantize_nvfp4(
     # reshape back to original shape
     scales = scales.view(*orig_leading_dims, -1, padded_cols)
     xq = xq.view(*orig_leading_dims, -1, N // 2)
+    xq = xq.view(torch.float4_e2m1fn_x2)

-    return scales, xq.view(torch.uint8)
+    return scales, xq

 @triton_quantize_nvfp4.register_fake
 def _(x, per_tensor_scale=None):
@@ -1043,7 +1044,7 @@ def _(x, per_tensor_scale=None):
     scales = torch.empty(
         padded_rows, padded_cols, device=x.device, dtype=torch.float8_e4m3fn
     )
-    xq = torch.empty(M, N // 2, device=x.device, dtype=torch.uint8)
+    xq = torch.empty(M, N // 2, device=x.device, dtype=torch.float4_e2m1fn_x2)
     return scales, xq

 @triton_mx_block_rearrange.register_fake
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index e9f7225647..4bee72cad5 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -321,6 +321,7 @@ def to_mx(
         data_lp = data_lp.reshape(orig_shape)
         data_lp = f32_to_f4_unpacked(data_lp)
         data_lp = pack_uint4(data_lp)
+        data_lp = data_lp.view(torch.float4_e2m1fn_x2)
     else:
         raise AssertionError("unsupported")

@@ -382,7 +383,7 @@ def to_dtype(
         data_hp = data_hp.to(target_dtype).reshape(orig_shape)
     elif elem_dtype == torch.float4_e2m1fn_x2:
         # fp4
-        f4_unpacked = unpack_uint4(data_lp)
+        f4_unpacked = unpack_uint4(data_lp.view(torch.uint8))
         # for now we only have a cast to f32
         # TODO(future PR): add cast directly to bf16
         f32 = f4_unpacked_to_f32(f4_unpacked)
@@ -483,6 +484,7 @@ def __new__(
             torch.float8_e4m3fn,
             torch.float8_e5m2,
             torch.uint8,
+            torch.float4_e2m1fn_x2,
         ), "unsupported"
         self.qdata = qdata
         self.scale = scale_e8m0_bits
diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py
index 8dbdc5ab15..8b2b05d38b 100644
--- a/torchao/prototype/mx_formats/nvfp4_tensor.py
+++ b/torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -478,8 +478,8 @@ def _addmm_nvfp4_dispatch(
     # should_add_bias_separately = bias is not None

     result = torch._scaled_mm(
-        a.qdata.view(torch.float4_e2m1fn_x2),
-        b.qdata.view(torch.float4_e2m1fn_x2),
+        a.qdata,
+        b.qdata,
         a_scale_blocked.view(torch.float8_e4m3fn),
         b_scale_blocked.view(torch.float8_e4m3fn),
         bias=None if should_add_bias_separately else bias,
@@ -685,7 +685,6 @@ def nvfp4_quantize(
     data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
     data_scaled = data_scaled.view(orig_shape)
     data_lp = f32_to_f4_unpacked(data_scaled)
-    # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
-    # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
    data_lp = pack_uint4(data_lp)
+    data_lp = data_lp.view(torch.float4_e2m1fn_x2)
     return out_scales, data_lp
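
Note (not part of the patch): the convention this change standardizes is that
packed fp4 payloads keep the torch.float4_e2m1fn_x2 dtype end to end, so
torch._scaled_mm consumes qdata directly, while byte-oriented helpers such as
unpack_uint4 take a uint8 reinterpretation. A minimal sketch of that round
trip, assuming PyTorch 2.8+ for the dtype; the names `packed`, `fp4_t`, and
`as_bytes` are illustrative, not from the patch:

    import torch

    # two e2m1 values are packed into each byte, so a (4, 8) uint8 tensor
    # holds 4 x 16 fp4 elements
    packed = torch.randint(0, 256, (4, 8), dtype=torch.uint8)

    # reinterpret the bytes as packed fp4; both dtypes are 1 byte per element,
    # so this is a zero-copy view and the shape stays (4, 8)
    fp4_t = packed.view(torch.float4_e2m1fn_x2)
    assert fp4_t.dtype == torch.float4_e2m1fn_x2

    # byte-level kernels (e.g. torchao's unpack_uint4) need the uint8 view back
    as_bytes = fp4_t.view(torch.uint8)
    assert torch.equal(as_bytes, packed)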