Skip to content

Commit c9fec32

Browse files
committed
Port ROCm changes from multi-backend-refactor branch
1 parent 5eb35ec commit c9fec32

10 files changed

Lines changed: 4654 additions & 49 deletions

File tree

CMakeLists.txt

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,14 @@ endif()
2525
# Define included source files
2626
set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp)
2727
set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
28+
set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
2829
set(MPS_FILES csrc/mps_ops.mm)
2930
set(METAL_FILES csrc/mps_kernels.metal)
3031
# C++ sources are always included
3132
list(APPEND SRC_FILES ${CPP_FILES})
3233

33-
set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps)")
34-
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps)
34+
set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)")
35+
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps)
3536
option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)
3637

3738
if(APPLE)
@@ -47,15 +48,25 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
4748
message(FATAL_ERROR "CUDA is not supported on macOS" )
4849
endif()
4950
set(BUILD_CUDA ON)
51+
set(BUILD_HIP OFF)
52+
set(BUILD_MPS OFF)
53+
elseif(${COMPUTE_BACKEND} STREQUAL "hip")
54+
if(APPLE)
55+
message(FATAL_ERROR "HIP is not supported on macOS" )
56+
endif()
57+
set(BUILD_CUDA OFF)
58+
set(BUILD_HIP ON)
5059
set(BUILD_MPS OFF)
5160
elseif(${COMPUTE_BACKEND} STREQUAL "mps")
5261
if(NOT APPLE)
5362
message(FATAL_ERROR "MPS is only supported on macOS" )
5463
endif()
5564
set(BUILD_CUDA OFF)
65+
set(BUILD_HIP OFF)
5666
set(BUILD_MPS ON)
5767
else()
5868
set(BUILD_CUDA OFF)
69+
set(BUILD_HIP OFF)
5970
set(BUILD_MPS OFF)
6071
endif()
6172

@@ -160,6 +171,36 @@ if(BUILD_CUDA)
160171

161172
string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
162173
add_compile_definitions(BUILD_CUDA)
174+
elseif(BUILD_HIP)
175+
enable_language(HIP)
176+
message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}")
177+
if(DEFINED BNB_ROCM_ARCH)
178+
set(CMAKE_HIP_ARCHITECTURES ${BNB_ROCM_ARCH})
179+
else()
180+
if (NOT AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
181+
set(CMAKE_HIP_ARCHITECTURES "gfx90a;gfx942;gfx1100")
182+
elseif (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
183+
set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
184+
endif()
185+
endif()
186+
message(STATUS "HIP Targets: ${CMAKE_HIP_ARCHITECTURES}")
187+
188+
list(APPEND SRC_FILES ${HIP_FILES})
189+
190+
string(APPEND BNB_OUTPUT_NAME "_rocm")
191+
192+
# get hip version
193+
execute_process(COMMAND hipconfig --version OUTPUT_VARIABLE HIP_CONFIG_VERSION)
194+
string(REGEX MATCH "[0-9]+\\.[0-9]+" HIP_VERSION "${HIP_CONFIG_VERSION}")
195+
string(REPLACE "." "" HIP_VERSION_SHORT "${HIP_VERSION}")
196+
197+
string(APPEND BNB_OUTPUT_NAME "${HIP_VERSION_SHORT}")
198+
if(HIP_VERSION VERSION_LESS "6.1")
199+
string(APPEND BNB_OUTPUT_NAME "_nohipblaslt")
200+
endif()
201+
add_compile_definitions(__HIP_PLATFORM_AMD__)
202+
add_compile_definitions(__HIP_PLATFORM_HCC__)
203+
add_compile_definitions(BUILD_HIP)
163204
elseif(BUILD_MPS)
164205
if(NOT APPLE)
165206
message(FATAL_ERROR "MPS is only supported on macOS" )
@@ -208,6 +249,41 @@ if(BUILD_CUDA)
208249
CUDA_SEPARABLE_COMPILATION ON
209250
)
210251
endif()
252+
if(BUILD_HIP)
253+
if(NOT DEFINED ENV{ROCM_PATH})
254+
set(ROCM_PATH /opt/rocm)
255+
else()
256+
set(ROCM_PATH $ENV{ROCM_PATH})
257+
endif()
258+
list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
259+
macro(find_package_and_print_version PACKAGE_NAME)
260+
find_package("${PACKAGE_NAME}" ${ARGN})
261+
message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}")
262+
endmacro()
263+
find_package_and_print_version(hipblas REQUIRED)
264+
find_package_and_print_version(hiprand REQUIRED)
265+
find_package_and_print_version(hipsparse REQUIRED)
266+
267+
## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. adam8bit because of inaccuracies)
268+
set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
269+
set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
270+
set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "")
271+
272+
target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include)
273+
target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib)
274+
target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse)
275+
276+
target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP)
277+
set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)
278+
set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX)
279+
280+
if(HIP_VERSION VERSION_LESS "6.1")
281+
target_compile_definitions(bitsandbytes PUBLIC NO_HIPBLASLT)
282+
else()
283+
find_package(hipblaslt)
284+
target_link_libraries(bitsandbytes PUBLIC roc::hipblaslt)
285+
endif()
286+
endif()
211287
if(BUILD_MPS)
212288
add_dependencies(bitsandbytes metallib)
213289
target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")

bitsandbytes/cextension.py

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,17 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
2222
"""
2323

2424
prefix = "rocm" if torch.version.hip else "cuda"
25-
library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"
25+
blas_suffix = "_nohipblaslt" if torch.version.hip and cuda_specs.cuda_version_tuple < (6, 1) else ""
26+
library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{blas_suffix}{DYNAMIC_LIBRARY_SUFFIX}"
2627

2728
override_value = os.environ.get("BNB_CUDA_VERSION")
2829
if override_value:
2930
library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1)
31+
if torch.version.hip:
32+
raise RuntimeError(
33+
f"BNB_CUDA_VERSION={override_value} detected for ROCm!! \n"
34+
f"Clear the variable and retry: export BNB_CUDA_VERSION=\n"
35+
)
3036
logger.warning(
3137
f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n"
3238
"This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n"
@@ -72,10 +78,11 @@ def __init__(self, lib: ct.CDLL):
7278

7379
def get_available_cuda_binary_versions() -> list[str]:
7480
"""Get formatted CUDA versions from existing library files using cuda_specs logic"""
75-
lib_pattern = f"libbitsandbytes_cuda*{DYNAMIC_LIBRARY_SUFFIX}"
81+
lib_pattern = f"libbitsandbytes_{BNB_BACKEND.lower()}*{DYNAMIC_LIBRARY_SUFFIX}"
7682
versions = []
7783
for lib in Path(__file__).parent.glob(lib_pattern):
78-
match = re.search(r"cuda(\d{3})", lib.name)
84+
pattern = r"{}(\d+)".format(BNB_BACKEND.lower())
85+
match = re.search(pattern, lib.name)
7986
if match:
8087
ver_code = int(match.group(1))
8188
major = ver_code // 10
@@ -86,8 +93,8 @@ def get_available_cuda_binary_versions() -> list[str]:
8693

8794
def parse_cuda_version(version_str: str) -> str:
8895
"""Convert raw version string (e.g. '118' from env var) to formatted version (e.g. '11.8')"""
89-
if version_str.isdigit() and len(version_str) == 3:
90-
return f"{version_str[:2]}.{version_str[2]}"
96+
if version_str.isdigit():
97+
return f"{version_str[:-1]}.{version_str[-1]}"
9198
return version_str # fallback as safety net
9299

93100

@@ -148,7 +155,7 @@ def _format_lib_error_message(
148155
"""Format detailed error message for library loading failures"""
149156
analysis = ""
150157
no_cpu_lib_found = "libbitsandbytes_cpu.so: cannot open" in original_error
151-
no_cuda_lib_found = "CUDA binary not found" in original_error
158+
no_cuda_lib_found = f"{BNB_BACKEND} binary not found" in original_error
152159

153160
if no_cpu_lib_found:
154161
analysis = "\n🚨 Failed to load CPU-only bitsandbytes library 🚨\n\n"
@@ -157,9 +164,9 @@ def _format_lib_error_message(
157164
version_list_str = "\n - " + "\n - ".join(available_versions) if available_versions else "NONE"
158165
analysis = (
159166
(
160-
f"\n🚨 CUDA VERSION MISMATCH 🚨\n"
161-
f"Requested CUDA version: {requested_version}\n"
162-
f"Detected PyTorch CUDA version: {user_cuda_version}\n"
167+
f"\n🚨 {BNB_BACKEND} VERSION MISMATCH 🚨\n"
168+
f"Requested {BNB_BACKEND} version: {requested_version}\n"
169+
f"Detected PyTorch {BNB_BACKEND} version: {user_cuda_version}\n"
163170
f"Available pre-compiled versions: {version_list_str}\n\n"
164171
"This means:\n"
165172
"The version you're trying to use is NOT distributed with this package\n\n"
@@ -174,42 +181,49 @@ def _format_lib_error_message(
174181

175182
troubleshooting = (
176183
(
177-
"This typically happens when:\n"
178-
"1. bitsandbytes doesn't ship with a pre-compiled binary for your CUDA version\n"
179-
"2. The library wasn't compiled properly during installation from source\n\n"
184+
f"This typically happens when:\n"
185+
f"1. bitsandbytes doesn't ship with a pre-compiled binary for your {BNB_BACKEND} version\n"
186+
f"2. The library wasn't compiled properly during installation from source\n\n"
180187
)
181188
if no_cuda_lib_found
182-
else "This typically happens when you checked the code out from source and your torch installation doesn't detect CUDA on your machine.\n\n"
189+
else f"This typically happens when you checked the code out from source and your torch installation doesn't detect {BNB_BACKEND} on your machine.\n\n"
183190
)
184191

185192
note = (
186193
(
187-
"To make bitsandbytes work, the compiled library version MUST exactly match the linked CUDA version.\n"
188-
"If your CUDA version doesn't have a pre-compiled binary, you MUST compile from source.\n\n"
194+
f"To make bitsandbytes work, the compiled library version MUST exactly match the linked {BNB_BACKEND} version.\n"
195+
f"If your {BNB_BACKEND} version doesn't have a pre-compiled binary, you MUST compile from source.\n\n"
189196
)
190197
if no_cuda_lib_found
191198
else ""
192199
)
193200

194201
compile_instructions = (
202+
(
203+
"COMPILE FROM SOURCE for CPU-only:\n `cmake -DCOMPUTE_BACKEND=cpu -S . && make`\n\n"
204+
) if not no_cuda_lib_found
205+
else
195206
(
196207
"You have two options:\n"
197208
"1. COMPILE FROM SOURCE (required if no binary exists):\n"
198209
" https://huggingface.co/docs/bitsandbytes/main/en/installation#cuda-compile\n"
199210
"2. Use BNB_CUDA_VERSION to specify a DIFFERENT CUDA version from the detected one, which is installed on your machine and matching an available pre-compiled version listed above\n\n"
211+
) if not HIP_ENVIRONMENT
212+
else
213+
(
214+
"You can COMPILE FROM SOURCE as mentioned here:\n"
215+
" https://huggingface.co/docs/bitsandbytes/main/en/installation?backend=AMD+ROCm#amd-gpu\n"
200216
)
201-
if no_cuda_lib_found
202-
else "COMPILE FROM SOURCE for CPU-only:\n `cmake -DCOMPUTE_BACKEND=cpu -S . && make`\n\n"
203217
)
204218

205219
diagnostics = (
206-
"🔍 Run this command for detailed diagnostics:\n"
207-
"python -m bitsandbytes\n\n"
208-
"If you've tried everything and still have issues:\n"
209-
"1. Include ALL version info (operating system, bitsandbytes, pytorch, cuda, python)\n"
210-
"2. Describe what you've tried in detail\n"
211-
"3. Open an issue with this information:\n"
212-
" https://github.com/bitsandbytes-foundation/bitsandbytes/issues\n\n"
220+
f"🔍 Run this command for detailed diagnostics:\n"
221+
f"python -m bitsandbytes\n\n"
222+
f"If you've tried everything and still have issues:\n"
223+
f"1. Include ALL version info (operating system, bitsandbytes, pytorch, {BNB_BACKEND.lower()}, python)\n"
224+
f"2. Describe what you've tried in detail\n"
225+
f"3. Open an issue with this information:\n"
226+
f" https://github.com/bitsandbytes-foundation/bitsandbytes/issues\n\n"
213227
)
214228

215229
return f"{analysis}{base_msg}{troubleshooting}{note}{compile_instructions}{original_error}\n{diagnostics}"
@@ -224,18 +238,19 @@ def _format_dependency_error(self) -> str:
224238
)
225239

226240
return (
227-
f"\n🚨 CUDA SETUP ERROR: Missing dependency: {missing_lib} 🚨\n\n"
228-
f"CUDA {cuda_major_version}.x runtime libraries were not found in the LD_LIBRARY_PATH.\n\n"
241+
f"\n🚨 {BNB_BACKEND} SETUP ERROR: Missing dependency: {missing_lib} 🚨\n\n"
242+
f"{BNB_BACKEND} {cuda_major_version}.x runtime libraries were not found in the LD_LIBRARY_PATH.\n\n"
229243
f"To fix this, make sure that:\n"
230-
f"1. You have installed CUDA {cuda_major_version}.x toolkit on your system\n"
231-
f"2. The CUDA runtime libraries are in your LD_LIBRARY_PATH\n\n"
244+
f"1. You have installed {BNB_BACKEND} {cuda_major_version}.x toolkit on your system\n"
245+
f"2. The {BNB_BACKEND} runtime libraries are in your LD_LIBRARY_PATH\n\n"
232246
f"You can add them with (and persist the change by adding the line to your .bashrc):\n"
233-
f" export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/path/to/cuda-{cuda_major_version}.x/lib64\n\n"
247+
f" export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/path/to/{BNB_BACKEND.lower()}-{cuda_major_version}.x/\
248+
{'lib64' if not HIP_ENVIRONMENT else 'lib'}\n\n"
234249
f"Original error: {self.error_msg}\n\n"
235250
f"🔍 Run this command for detailed diagnostics:\n"
236251
f"python -m bitsandbytes\n\n"
237252
f"If you've tried everything and still have issues:\n"
238-
f"1. Include ALL version info (operating system, bitsandbytes, pytorch, cuda, python)\n"
253+
f"1. Include ALL version info (operating system, bitsandbytes, pytorch, {BNB_BACKEND.lower()}, python)\n"
239254
f"2. Describe what you've tried in detail\n"
240255
f"3. Open an issue with this information:\n"
241256
f" https://github.com/bitsandbytes-foundation/bitsandbytes/issues\n\n"
@@ -264,7 +279,7 @@ def get_native_library() -> BNBNativeLibrary:
264279
cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)
265280

266281
if not cuda_binary_path.exists():
267-
raise RuntimeError(f"Configured CUDA binary not found at {cuda_binary_path}")
282+
raise RuntimeError(f"Configured {BNB_BACKEND} binary not found at {cuda_binary_path}")
268283

269284
binary_path = cuda_binary_path
270285

@@ -284,6 +299,11 @@ def get_native_library() -> BNBNativeLibrary:
284299

285300

286301
try:
302+
if torch.version.hip:
303+
HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm"
304+
else:
305+
HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA"
306+
287307
lib = get_native_library()
288308
except Exception as e:
289309
error_msg = str(e)

0 commit comments

Comments (0)