|
8 | 8 | import torch |
9 | 9 |
|
10 | 10 | import bitsandbytes as bnb |
| 11 | +from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH |
11 | 12 | from bitsandbytes import functional as F |
12 | 13 | from tests.helpers import ( |
13 | 14 | BOOLEAN_TUPLES, |
@@ -91,7 +92,7 @@ class Test8BitBlockwiseQuantizeFunctional: |
91 | 92 | @pytest.mark.parametrize("device", get_available_devices()) |
92 | 93 | @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) |
93 | 94 | @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) |
94 | | - @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) |
| 95 | +    @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64] if not HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128])
95 | 96 | @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) |
96 | 97 | def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): |
97 | 98 | iters = 100 |
@@ -795,7 +796,7 @@ def test_coo_int8_vectorwise_quant(self, device, dim1, dim2): |
795 | 796 | A[:, outlier_cols] = 0 |
796 | 797 | torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) |
797 | 798 |
|
798 | | - |
| 799 | +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") |
799 | 800 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") |
800 | 801 | class TestSpMMFunctional: |
801 | 802 | @pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1")) |
@@ -1105,7 +1106,7 @@ class TestQuantize4BitFunctional: |
1105 | 1106 | @pytest.mark.parametrize("device", get_available_devices()) |
1106 | 1107 | @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) |
1107 | 1108 | @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) |
1108 | | - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096]) |
| 1109 | + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096]) |
1109 | 1110 | def test_4bit_quant(self, device, dtype, quant_type, blocksize): |
1110 | 1111 | if device == "cpu" and quant_type != "nf4": |
1111 | 1112 | pytest.xfail("fp4 quantization is not supported on CPU") |
@@ -1140,7 +1141,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): |
1140 | 1141 |
|
1141 | 1142 | @pytest.mark.parametrize("device", get_available_devices()) |
1142 | 1143 | @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) |
1143 | | - @pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize")) |
| 1144 | + @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize")) |
1144 | 1145 | def test_4bit_compressed_stats(self, device, quant_type, blocksize): |
1145 | 1146 | if device == "cpu" and quant_type != "nf4": |
1146 | 1147 | pytest.xfail("fp4 quantization is not supported on CPU") |
@@ -1204,7 +1205,10 @@ def test_bench_4bit_dequant(self, quant_type): |
1204 | 1205 | # torch.matmul(b, a.t()) |
1205 | 1206 | # torch.cuda.synchronize() |
1206 | 1207 | # print((time.time()-t0)/iters*1e6) |
1207 | | - |
| 1208 | + |
| 1209 | + @pytest.mark.skipif( |
| 1210 | + HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64" |
| 1211 | + ) |
1208 | 1212 | @pytest.mark.parametrize("device", get_available_devices()) |
1209 | 1213 | @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") |
1210 | 1214 | @pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) |
@@ -1368,6 +1372,10 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double |
1368 | 1372 | @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) |
1369 | 1373 | @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) |
1370 | 1374 | @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) |
| 1375 | + @pytest.mark.skipif( |
| 1376 | + HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a", |
| 1377 | + reason="this test is not supported on ROCm with gfx90a architecture yet", |
| 1378 | + ) |
1371 | 1379 | def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant): |
1372 | 1380 | if device == "cpu" and storage_type != "nf4": |
1373 | 1381 | pytest.xfail("fp4 quantization is not supported on CPU") |
|
0 commit comments