Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions tests/pytorch/test_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,3 +656,86 @@ def test_chunk(
tols = dict(rtol=0, atol=0) # Chunking is exact
y_test = y_test.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)


@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
class TestMXFP8Tensor:
    """Dequantization tests for MXFP8 tensors that hold only columnwise data."""

    @classmethod
    def setup_class(cls) -> None:
        """Seed the CPU and CUDA RNGs before the tests in this class run."""
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("dims", [[128, 128], [256, 256], [128, 256]])
    def test_mxfp8_dequantize_columnwise_only(
        self,
        fp8_dtype: tex.DType,
        dtype: torch.dtype,
        dims: DimsType,
    ) -> None:
        """Check dequantization of MXFP8 tensor with only columnwise data

        The tensor is quantized with both usages, dequantized once, then
        stripped of its rowwise data and dequantized again. Both results
        must match the reference and each other within ``_tols[fp8_dtype]``.
        """

        # Initialize random data
        x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cuda") - 1

        # Quantize with both rowwise and columnwise
        quantizer = MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True)
        x_mxfp8 = quantizer(x_ref)

        # Dequantize from rowwise (default path)
        x_deq_rowwise = x_mxfp8.dequantize(dtype=dtype)

        # Rowwise dequantization should be close to the original
        torch.testing.assert_close(x_deq_rowwise, x_ref, **_tols[fp8_dtype])

        # Strip rowwise data, keeping only columnwise
        x_mxfp8.update_usage(rowwise_usage=False, columnwise_usage=True)
        assert x_mxfp8._rowwise_data is None
        assert x_mxfp8._columnwise_data is not None

        # Dequantize from columnwise only
        x_deq_columnwise = x_mxfp8.dequantize(dtype=dtype)

        # Columnwise dequantization should be close to the original
        torch.testing.assert_close(x_deq_columnwise, x_ref, **_tols[fp8_dtype])

        # Rowwise and columnwise dequantizations should match each other
        torch.testing.assert_close(x_deq_columnwise, x_deq_rowwise, **_tols[fp8_dtype])

        # Make sure we are not trivially passing the test
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_deq_columnwise, -x_ref, **_tols[fp8_dtype])

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("dims", [[128, 128], [256, 256], [128, 256]])
    def test_mxfp8_dequantize_columnwise_only_quantized_separately(
        self,
        fp8_dtype: tex.DType,
        dtype: torch.dtype,
        dims: DimsType,
    ) -> None:
        """Check dequantization of MXFP8 tensor quantized with columnwise only

        Unlike ``test_mxfp8_dequantize_columnwise_only``, the quantizer never
        produces rowwise data here. The test is parametrized over the same
        dtypes and shapes (including the non-square ``[128, 256]`` case) so
        that both columnwise-only paths get identical coverage.
        """

        # Initialize random data
        x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cuda") - 1

        # Quantize with columnwise only (no rowwise)
        quantizer = MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=False, columnwise=True)
        x_mxfp8 = quantizer(x_ref)
        assert x_mxfp8._rowwise_data is None
        assert x_mxfp8._columnwise_data is not None

        # Dequantize from columnwise only
        x_deq = x_mxfp8.dequantize(dtype=dtype)

        # Should be close to the original
        torch.testing.assert_close(x_deq, x_ref, **_tols[fp8_dtype])

        # Make sure we are not trivially passing the test
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_deq, -x_ref, **_tols[fp8_dtype])
Comment on lines +707 to +741
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Limited dtype and dimension coverage in second test

test_mxfp8_dequantize_columnwise_only_quantized_separately hardcodes dtype = torch.bfloat16 and omits the asymmetric [128, 256] dimension, while the companion test test_mxfp8_dequantize_columnwise_only is fully parameterized over _dtypes (float32, float16, bfloat16) and includes [128, 256].

If the columnwise-only quantization path genuinely cannot handle float32/float16 inputs or non-square tensors, that constraint is invisible from the test and should be documented. If it can, the missing coverage means regressions in those cases could go undetected.

Consider either:

  • Parameterizing the second test the same way as the first (adding @pytest.mark.parametrize("dtype", _dtypes) and [128, 256] to dims), or
  • Adding a comment explaining why those combinations are intentionally excluded.
Suggested change
torch.testing.assert_close(x_deq_columnwise, x_deq_rowwise, **_tols[fp8_dtype])
# Make sure we are not trivially passing the test
with pytest.raises(AssertionError):
torch.testing.assert_close(x_deq_columnwise, -x_ref, **_tols[fp8_dtype])
@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
@pytest.mark.parametrize("dims", [[128, 128], [256, 256]])
def test_mxfp8_dequantize_columnwise_only_quantized_separately(
self,
fp8_dtype: tex.DType,
dims: DimsType,
) -> None:
"""Check dequantization of MXFP8 tensor quantized with columnwise only"""
dtype = torch.bfloat16
# Initialize random data
x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cuda") - 1
# Quantize with columnwise only (no rowwise)
quantizer = MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=False, columnwise=True)
x_mxfp8 = quantizer(x_ref)
assert x_mxfp8._rowwise_data is None
assert x_mxfp8._columnwise_data is not None
# Dequantize from columnwise only
x_deq = x_mxfp8.dequantize(dtype=dtype)
# Should be close to the original
torch.testing.assert_close(x_deq, x_ref, **_tols[fp8_dtype])
# Make sure we are not trivially passing the test
with pytest.raises(AssertionError):
torch.testing.assert_close(x_deq, -x_ref, **_tols[fp8_dtype])
@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("dims", [[128, 128], [256, 256], [128, 256]])
def test_mxfp8_dequantize_columnwise_only_quantized_separately(
self,
fp8_dtype: tex.DType,
dtype: torch.dtype,
dims: DimsType,
) -> None:
"""Check dequantization of MXFP8 tensor quantized with columnwise only"""

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def forward(
dtype = torch_to_transformer_engine_dtype[dtype]

# Make sure FP8 data is in expected format
if tensor._rowwise_data is not None:
if tensor._rowwise_data is not None or tensor._columnwise_data is not None:
return tex.dequantize(tensor, dtype)
raise NotImplementedError("Casting back from the transpose not implemented yet!")
raise ValueError("Cannot dequantize MXFP8 tensor with no data")

@staticmethod
def backward(
Expand Down
Loading