Skip to content

Commit 359d545

Browse files
committed
Skip unsupported tests on ROCm
1 parent 0507a45 commit 359d545

3 files changed

Lines changed: 5 additions & 4 deletions

File tree

tests/test_functional.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,7 @@ def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim):
463463
@pytest.mark.parametrize("hidden_dim", [32, 1024 * 4], ids=id_formatter("hidden_dim"))
464464
@pytest.mark.parametrize("batch_dim", [2, 16], ids=id_formatter("batch_dim"))
465465
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
466+
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
466467
def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose):
467468
def min_max(x):
468469
maxA = torch.amax(x, dim=2, keepdim=True)
@@ -1110,6 +1111,7 @@ class TestQuantize4BitFunctional:
11101111
"blocksize",
11111112
[64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096],
11121113
)
1114+
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
11131115
def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11141116
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
11151117
pytest.skip("This configuration is not supported on HPU.")
@@ -1408,10 +1410,7 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
14081410
@pytest.mark.parametrize("device", get_available_devices())
14091411
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
14101412
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
1411-
@pytest.mark.skipif(
1412-
HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a",
1413-
reason="this test is not supported on ROCm with gfx90a architecture yet",
1414-
)
1413+
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
14151414
def test_gemv_eye_4bit(self, device, storage_type, dtype):
14161415
if device == "hpu" and not is_supported_on_hpu(storage_type, dtype):
14171416
pytest.skip("This configuration is not supported on HPU.")

tests/test_linear8bitlt.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ def test_linear8bit_serialization(linear8bit):
233233
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
234234
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
235235
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
236+
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
236237
def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
237238
if device == "cuda" and platform.system() == "Windows":
238239
pytest.skip("Triton is not officially supported on Windows")

tests/test_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi
211211
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
212212
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
213213
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512])
214+
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
214215
def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
215216
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
216217
pytest.skip("This configuration is not supported on HPU.")

0 commit comments

Comments (0)