Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/prototype/mx_formats/test_mx_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
prev_dims, K = data_hp.shape[:-1], data_hp.shape[-1]
if elem_dtype is torch.float4_e2m1fn_x2:
assert data_mx.qdata.shape == (*prev_dims, K // 2)
assert data_mx.qdata.dtype == torch.float4_e2m1fn_x2
else:
assert data_mx.qdata.shape == (*prev_dims, K)
assert data_mx.scale.shape == (*prev_dims, K // block_size)
Expand Down
18 changes: 16 additions & 2 deletions test/prototype/mx_formats/test_nvfp4_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
)

torch.testing.assert_close(nvfp4_pt.scale.flatten(), nvfp4_triton.scale.flatten())
pt_unpacked = unpack_uint4(nvfp4_pt.qdata)
triton_unpacked = unpack_uint4(nvfp4_triton.qdata)
pt_unpacked = unpack_uint4(nvfp4_pt.qdata.view(torch.uint8))
triton_unpacked = unpack_uint4(nvfp4_triton.qdata.view(torch.uint8))
torch.testing.assert_close(
pt_unpacked,
triton_unpacked,
Expand Down Expand Up @@ -611,3 +611,17 @@ def test_3d_transpose(dims, is_swizzled_scales):
x_hp_t = x_hp.transpose(dims[0], dims[1])
x_nvfp4_t = x_nvfp4.transpose(dims[0], dims[1])
assert x_hp_t.shape == x_nvfp4_t.shape


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
)
@pytest.mark.parametrize("use_triton_kernel", [False, True])
def test_uses_fp4_qdata(use_triton_kernel):
x_hp = torch.randn(2, 128, 256, device="cuda")
# TODO also test triton kernel
x_nvfp4 = NVFP4Tensor.to_nvfp4(
x_hp, use_triton_kernel=use_triton_kernel, is_swizzled_scales=True
)
assert x_nvfp4.qdata.dtype == torch.float4_e2m1fn_x2
5 changes: 3 additions & 2 deletions torchao/prototype/mx_formats/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,8 +1028,9 @@ def triton_quantize_nvfp4(
# reshape back to original shape
scales = scales.view(*orig_leading_dims, -1, padded_cols)
xq = xq.view(*orig_leading_dims, -1, N // 2)
xq = xq.view(torch.float4_e2m1fn_x2)

return scales, xq.view(torch.uint8)
return scales, xq

@triton_quantize_nvfp4.register_fake
def _(x, per_tensor_scale=None):
Expand All @@ -1043,7 +1044,7 @@ def _(x, per_tensor_scale=None):
scales = torch.empty(
padded_rows, padded_cols, device=x.device, dtype=torch.float8_e4m3fn
)
xq = torch.empty(M, N // 2, device=x.device, dtype=torch.uint8)
xq = torch.empty(M, N // 2, device=x.device, dtype=torch.float4_e2m1fn_x2)
return scales, xq

@triton_mx_block_rearrange.register_fake
Expand Down
4 changes: 3 additions & 1 deletion torchao/prototype/mx_formats/mx_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ def to_mx(
data_lp = data_lp.reshape(orig_shape)
data_lp = f32_to_f4_unpacked(data_lp)
data_lp = pack_uint4(data_lp)
data_lp = data_lp.view(torch.float4_e2m1fn_x2)
else:
raise AssertionError("unsupported")

Expand Down Expand Up @@ -382,7 +383,7 @@ def to_dtype(
data_hp = data_hp.to(target_dtype).reshape(orig_shape)
elif elem_dtype == torch.float4_e2m1fn_x2:
# fp4
f4_unpacked = unpack_uint4(data_lp)
f4_unpacked = unpack_uint4(data_lp.view(torch.uint8))
# for now we only have a cast to f32
# TODO(future PR): add cast directly to bf16
f32 = f4_unpacked_to_f32(f4_unpacked)
Expand Down Expand Up @@ -483,6 +484,7 @@ def __new__(
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.uint8,
torch.float4_e2m1fn_x2,
), "unsupported"
self.qdata = qdata
self.scale = scale_e8m0_bits
Expand Down
7 changes: 3 additions & 4 deletions torchao/prototype/mx_formats/nvfp4_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,8 +478,8 @@ def _addmm_nvfp4_dispatch(
# should_add_bias_separately = bias is not None

result = torch._scaled_mm(
a.qdata.view(torch.float4_e2m1fn_x2),
b.qdata.view(torch.float4_e2m1fn_x2),
a.qdata,
b.qdata,
a_scale_blocked.view(torch.float8_e4m3fn),
b_scale_blocked.view(torch.float8_e4m3fn),
bias=None if should_add_bias_separately else bias,
Expand Down Expand Up @@ -685,7 +685,6 @@ def nvfp4_quantize(
data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
data_scaled = data_scaled.view(orig_shape)
data_lp = f32_to_f4_unpacked(data_scaled)
# TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
# data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
data_lp = pack_uint4(data_lp)
data_lp = data_lp.view(torch.float4_e2m1fn_x2)
return out_scales, data_lp
Loading