Skip to content

Commit 79fc632

Browse files
authored
Merge pull request #65 from MISHANMAURYA/upstream_main_rocm_enabled
Generalize Ops and Functional
2 parents c9fec32 + b96905d commit 79fc632

9 files changed

Lines changed: 107 additions & 32 deletions

File tree

bitsandbytes/backends/cuda/ops.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr
99

1010
from ..._ops import register_kernel
11-
from ...cextension import lib
11+
from ...cextension import lib, HIP_ENVIRONMENT
1212

1313

1414
@register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
@@ -210,7 +210,12 @@ def _get_col_absmax(
210210
@register_kernel("bitsandbytes::quantize_blockwise", "cuda")
211211
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
212212
torch._check_is_size(blocksize)
213-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
213+
214+
if HIP_ENVIRONMENT:
215+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
216+
else:
217+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
218+
214219
torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
215220

216221
n = A.numel()
@@ -264,7 +269,11 @@ def _(
264269
def _dequantize_blockwise_impl(
265270
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
266271
) -> None:
267-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
272+
if HIP_ENVIRONMENT:
273+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
274+
else:
275+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
276+
268277
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
269278
torch._check(
270279
dtype in [torch.float16, torch.bfloat16, torch.float32],
@@ -294,7 +303,11 @@ def _dequantize_blockwise_impl(
294303
def _(
295304
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
296305
) -> tuple[torch.Tensor, torch.Tensor]:
297-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
306+
if HIP_ENVIRONMENT:
307+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
308+
else:
309+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
310+
298311
torch._check(quant_type in ["fp4", "nf4"])
299312
torch._check(
300313
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
@@ -372,7 +385,11 @@ def _dequantize_4bit_impl(
372385
dtype: torch.dtype,
373386
out: torch.Tensor,
374387
) -> None:
375-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
388+
if HIP_ENVIRONMENT:
389+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
390+
else:
391+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
392+
376393
torch._check(quant_type in ["fp4", "nf4"])
377394
torch._check(
378395
dtype in [torch.bfloat16, torch.float16, torch.float32],

bitsandbytes/cextension.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch
99

1010
from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR
11-
from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_cuda_version_tuple
11+
from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_cuda_version_tuple, get_rocm_gpu_arch
1212

1313
logger = logging.getLogger(__name__)
1414

@@ -298,6 +298,8 @@ def get_native_library() -> BNBNativeLibrary:
298298
return BNBNativeLibrary(dll)
299299

300300

301+
ROCM_GPU_ARCH = get_rocm_gpu_arch()
302+
301303
try:
302304
if torch.version.hip:
303305
HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm"

bitsandbytes/cuda_specs.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import dataclasses
2+
import logging
3+
import re
4+
import subprocess
25
from functools import lru_cache
36
from typing import Optional
47

@@ -73,3 +76,27 @@ def get_cuda_specs() -> Optional[CUDASpecs]:
7376
)
7477
except Exception:
7578
return None
79+
80+
81+
def get_rocm_gpu_arch() -> str:
82+
"""Get ROCm GPU architecture."""
83+
logger = logging.getLogger(__name__)
84+
try:
85+
if torch.version.hip:
86+
result = subprocess.run(["rocminfo"], capture_output=True, text=True)
87+
match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout)
88+
if match:
89+
return "gfx" + match.group(1)
90+
else:
91+
return "unknown"
92+
else:
93+
return "unknown"
94+
except Exception as e:
95+
logger.error(f"Could not detect ROCm GPU architecture: {e}")
96+
if torch.cuda.is_available():
97+
logger.warning(
98+
"""
99+
ROCm GPU architecture detection failed despite ROCm being available.
100+
""",
101+
)
102+
return "unknown"

bitsandbytes/functional.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict
1717

18-
from .cextension import lib
18+
from .cextension import lib, HIP_ENVIRONMENT
1919

2020
name2qmap = {}
2121

@@ -953,29 +953,33 @@ def quantize_fp4(
953953
A: torch.Tensor,
954954
absmax: Optional[torch.Tensor] = None,
955955
out: Optional[torch.Tensor] = None,
956-
blocksize=64,
956+
blocksize=None,
957957
compress_statistics=False,
958958
quant_storage=torch.uint8,
959959
):
960+
if blocksize is None:
961+
blocksize = 64 if not HIP_ENVIRONMENT else 128
960962
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage)
961963

962964

963965
def quantize_nf4(
964966
A: torch.Tensor,
965967
absmax: Optional[torch.Tensor] = None,
966968
out: Optional[torch.Tensor] = None,
967-
blocksize=64,
969+
blocksize=None,
968970
compress_statistics=False,
969971
quant_storage=torch.uint8,
970972
):
973+
if blocksize is None:
974+
blocksize = 64 if not HIP_ENVIRONMENT else 128
971975
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage)
972976

973977

974978
def quantize_4bit(
975979
A: torch.Tensor,
976980
absmax: Optional[torch.Tensor] = None,
977981
out: Optional[torch.Tensor] = None,
978-
blocksize=64,
982+
blocksize=None,
979983
compress_statistics=False,
980984
quant_type="fp4",
981985
quant_storage=torch.uint8,
@@ -1003,6 +1007,10 @@ def quantize_4bit(
10031007
- `torch.Tensor`: The quantized tensor with packed 4-bit values.
10041008
- [`QuantState`]: The state object used to undo the quantization.
10051009
"""
1010+
1011+
if blocksize is None:
1012+
blocksize = 64 if not HIP_ENVIRONMENT else 128
1013+
10061014
input_shape = A.shape
10071015

10081016
_out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default(
@@ -1053,8 +1061,10 @@ def dequantize_fp4(
10531061
quant_state: Optional[QuantState] = None,
10541062
absmax: Optional[torch.Tensor] = None,
10551063
out: Optional[torch.Tensor] = None,
1056-
blocksize: int = 64,
1064+
blocksize: Optional[int] = None,
10571065
) -> torch.Tensor:
1066+
if blocksize is None:
1067+
blocksize = 64 if not HIP_ENVIRONMENT else 128
10581068
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")
10591069

10601070

@@ -1063,8 +1073,10 @@ def dequantize_nf4(
10631073
quant_state: Optional[QuantState] = None,
10641074
absmax: Optional[torch.Tensor] = None,
10651075
out: Optional[torch.Tensor] = None,
1066-
blocksize: int = 64,
1076+
blocksize: Optional[int] = None,
10671077
) -> torch.Tensor:
1078+
if blocksize is None:
1079+
blocksize = 64 if not HIP_ENVIRONMENT else 128
10681080
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")
10691081

10701082

@@ -1073,7 +1085,7 @@ def dequantize_4bit(
10731085
quant_state: Optional[QuantState] = None,
10741086
absmax: Optional[torch.Tensor] = None,
10751087
out: Optional[torch.Tensor] = None,
1076-
blocksize: int = 64,
1088+
blocksize: Optional[int] = None,
10771089
quant_type="fp4",
10781090
) -> torch.Tensor:
10791091
"""Dequantizes a packed 4-bit quantized tensor.
@@ -1102,6 +1114,10 @@ def dequantize_4bit(
11021114
Returns:
11031115
`torch.Tensor`: The dequantized tensor.
11041116
"""
1117+
1118+
if blocksize is None:
1119+
blocksize = 64 if not HIP_ENVIRONMENT else 128
1120+
11051121
if quant_state is None:
11061122
assert absmax is not None and out is not None
11071123

bitsandbytes/nn/modules.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import torch.nn.functional as F
1212

1313
import bitsandbytes as bnb
14+
from bitsandbytes.cextension import HIP_ENVIRONMENT
1415
from bitsandbytes.functional import QuantState
1516
from bitsandbytes.optim import GlobalOptimManager
1617
from bitsandbytes.utils import (
@@ -212,7 +213,7 @@ def __new__(
212213
data: Optional[torch.Tensor] = None,
213214
requires_grad=False, # quantized weights should be frozen by default
214215
quant_state: Optional[QuantState] = None,
215-
blocksize: int = 64,
216+
blocksize: Optional[int] = None,
216217
compress_statistics: bool = True,
217218
quant_type: str = "fp4",
218219
quant_storage: torch.dtype = torch.uint8,
@@ -221,7 +222,10 @@ def __new__(
221222
) -> "Params4bit":
222223
if data is None:
223224
data = torch.empty(0)
224-
225+
226+
if blocksize is None:
227+
blocksize = 64 if not HIP_ENVIRONMENT else 128
228+
225229
self = torch.Tensor._make_subclass(cls, data, requires_grad)
226230
self.blocksize = blocksize
227231
self.compress_statistics = compress_statistics

tests/test_cuda_setup_evaluator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from bitsandbytes.cextension import get_cuda_bnb_library_path
3+
from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path
44
from bitsandbytes.cuda_specs import CUDASpecs
55

66

@@ -12,12 +12,12 @@ def cuda120_spec() -> CUDASpecs:
1212
cuda_version_tuple=(12, 0),
1313
)
1414

15-
15+
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm")
1616
def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec):
1717
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
1818
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120"
1919

20-
20+
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm")
2121
def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
2222
monkeypatch.setenv("BNB_CUDA_VERSION", "110")
2323
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"

tests/test_functional.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch
99

1010
import bitsandbytes as bnb
11+
from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH
1112
from bitsandbytes import functional as F
1213
from tests.helpers import (
1314
BOOLEAN_TUPLES,
@@ -91,7 +92,7 @@ class Test8BitBlockwiseQuantizeFunctional:
9192
@pytest.mark.parametrize("device", get_available_devices())
9293
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
9394
@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
94-
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
95+
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64] if not HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128] )
9596
@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
9697
def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed):
9798
iters = 100
@@ -795,7 +796,7 @@ def test_coo_int8_vectorwise_quant(self, device, dim1, dim2):
795796
A[:, outlier_cols] = 0
796797
torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
797798

798-
799+
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
799800
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
800801
class TestSpMMFunctional:
801802
@pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1"))
@@ -1105,7 +1106,7 @@ class TestQuantize4BitFunctional:
11051106
@pytest.mark.parametrize("device", get_available_devices())
11061107
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
11071108
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
1108-
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
1109+
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096])
11091110
def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11101111
if device == "cpu" and quant_type != "nf4":
11111112
pytest.xfail("fp4 quantization is not supported on CPU")
@@ -1140,7 +1141,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11401141

11411142
@pytest.mark.parametrize("device", get_available_devices())
11421143
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
1143-
@pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize"))
1144+
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize"))
11441145
def test_4bit_compressed_stats(self, device, quant_type, blocksize):
11451146
if device == "cpu" and quant_type != "nf4":
11461147
pytest.xfail("fp4 quantization is not supported on CPU")
@@ -1204,7 +1205,10 @@ def test_bench_4bit_dequant(self, quant_type):
12041205
# torch.matmul(b, a.t())
12051206
# torch.cuda.synchronize()
12061207
# print((time.time()-t0)/iters*1e6)
1207-
1208+
1209+
@pytest.mark.skipif(
1210+
HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64"
1211+
)
12081212
@pytest.mark.parametrize("device", get_available_devices())
12091213
@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
12101214
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"])
@@ -1368,6 +1372,10 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
13681372
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
13691373
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
13701374
@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
1375+
@pytest.mark.skipif(
1376+
HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a",
1377+
reason="this test is not supported on ROCm with gfx90a architecture yet",
1378+
)
13711379
def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
13721380
if device == "cpu" and storage_type != "nf4":
13731381
pytest.xfail("fp4 quantization is not supported on CPU")

tests/test_linear4bit.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88

99
import bitsandbytes as bnb
10+
from bitsandbytes.cextension import HIP_ENVIRONMENT
1011
from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer
1112

1213
storage = {
@@ -16,7 +17,6 @@
1617
"float32": torch.float32,
1718
}
1819

19-
2020
@pytest.mark.parametrize("device", get_available_devices())
2121
@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
2222
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
@@ -183,7 +183,7 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua
183183

184184
@pytest.mark.parametrize("device", get_available_devices())
185185
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
186-
@pytest.mark.parametrize("blocksize", [64, 128])
186+
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
187187
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
188188
def test_copy_param(device, quant_type, blocksize, compress_statistics):
189189
if device == "cpu":
@@ -208,7 +208,7 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics):
208208

209209
@pytest.mark.parametrize("device", get_available_devices())
210210
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
211-
@pytest.mark.parametrize("blocksize", [64, 128])
211+
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
212212
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
213213
def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
214214
if device == "cpu":
@@ -240,7 +240,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
240240

241241
@pytest.mark.parametrize("device", get_available_devices())
242242
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
243-
@pytest.mark.parametrize("blocksize", [64, 128])
243+
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
244244
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
245245
def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
246246
if device == "cpu":

0 commit comments

Comments
 (0)