diff --git a/.github/scripts/build-rocm.sh b/.github/scripts/build-rocm.sh index d0efa5bb3..f88c4fc26 100644 --- a/.github/scripts/build-rocm.sh +++ b/.github/scripts/build-rocm.sh @@ -19,7 +19,7 @@ if [ "${build_os:0:6}" == ubuntu ]; then -w /src -v "$PWD:/src" "$image" sh -c \ "apt-get update \ && pip install cmake==3.31.6 \ - && cmake -DCOMPUTE_BACKEND=hip -DCMAKE_BUILD_TYPE=MinSizeRel -DCMAKE_HIP_FLAGS=\"--offload-compress\" -DBNB_ROCM_ARCH=\"${bnb_rocm_arch}\" . \ + && cmake -DCOMPUTE_BACKEND=rocm -DCMAKE_BUILD_TYPE=MinSizeRel -DCMAKE_HIP_FLAGS=\"--offload-compress\" -DBNB_ROCM_ARCH=\"${bnb_rocm_arch}\" . \ && cmake --build ." fi diff --git a/CMakeLists.txt b/CMakeLists.txt index 81326ffcf..54d8cc2f0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,7 @@ # For GCC: `cmake -B build . && cmake --build build` # For MSVC: `cmake -B build . && cmake --build build --config Release` # You can also use the following options and variables -# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend +# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, `rocm` or `mps` to select the backend # - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version # is whatever CMake finds on your path. # - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC. @@ -16,8 +16,8 @@ # libbitsandbytes_rocm70.so even if the system has ROCm 7.2. cmake_minimum_required(VERSION 3.22.1) -# On Windows with HIP backend, auto-detect compilers from ROCM_PATH before project() -if(WIN32 AND COMPUTE_BACKEND STREQUAL "hip") +# On Windows with ROCm backend, auto-detect compilers from ROCM_PATH before project() +if(WIN32 AND COMPUTE_BACKEND STREQUAL "rocm") if(DEFINED ENV{ROCM_PATH}) set(ROCM_PATH $ENV{ROCM_PATH}) endif() @@ -61,8 +61,8 @@ set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp) # C++ sources are always included list(APPEND SRC_FILES ${CPP_FILES}) -set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, xpu)") -set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps xpu) +set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, rocm, mps, xpu)") +set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda rocm mps xpu) option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF) if(APPLE) @@ -78,33 +78,33 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda") message(FATAL_ERROR "CUDA is not supported on macOS" ) endif() set(BUILD_CUDA ON) - set(BUILD_HIP OFF) + set(BUILD_ROCM OFF) set(BUILD_MPS OFF) -elseif(${COMPUTE_BACKEND} STREQUAL "hip") +elseif(${COMPUTE_BACKEND} STREQUAL "rocm") if(APPLE) - message(FATAL_ERROR "HIP is not supported on macOS" ) + message(FATAL_ERROR "ROCm is not supported on macOS" ) endif() set(BUILD_CUDA OFF) - set(BUILD_HIP ON) + set(BUILD_ROCM ON) set(BUILD_MPS OFF) elseif(${COMPUTE_BACKEND} STREQUAL "mps") if(NOT APPLE) message(FATAL_ERROR "MPS is only supported on macOS" ) endif() set(BUILD_CUDA OFF) - set(BUILD_HIP OFF) + set(BUILD_ROCM OFF) set(BUILD_MPS ON) elseif(${COMPUTE_BACKEND} STREQUAL "xpu") if(APPLE) message(FATAL_ERROR "XPU is not supported on macOS" ) endif() set(BUILD_CUDA OFF) - set(BUILD_HIP OFF) + set(BUILD_ROCM OFF) set(BUILD_MPS OFF) set(BUILD_XPU ON) else() set(BUILD_CUDA OFF) - set(BUILD_HIP OFF) + set(BUILD_ROCM OFF) set(BUILD_MPS OFF) set(BUILD_XPU OFF) set(BUILD_CPU ON) @@ -228,7 +228,7 @@ if(BUILD_CUDA) string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}") add_compile_definitions(BUILD_CUDA) -elseif(BUILD_HIP) +elseif(BUILD_ROCM) # Set target architectures before enable_language(HIP), which would otherwise # auto-detect a single GPU and override the defaults. if(DEFINED BNB_ROCM_ARCH) @@ -247,23 +247,96 @@ elseif(BUILD_HIP) string(APPEND BNB_OUTPUT_NAME "_rocm") - # get hip version - execute_process(COMMAND hipconfig --version OUTPUT_VARIABLE HIP_CONFIG_VERSION) - string(REGEX MATCH "[0-9]+\\.[0-9]+" HIP_VERSION "${HIP_CONFIG_VERSION}") - string(REPLACE "." "" HIP_VERSION_SHORT "${HIP_VERSION}") + # Always initialize this so downstream version checks are deterministic. + set(_DETECTED_ROCM_VERSION "0.0") + set(_ROCM_VERSION_SHORT "") + + # Allow the user to skip all version detection by passing -DROCM_VERSION= + if(DEFINED ROCM_VERSION) + message(STATUS "ROCm Version: ${ROCM_VERSION} (user-supplied via -DROCM_VERSION)") + # Assume user-supplied ROCM_VERSION is a shortcode (e.g. 71). + set(_ROCM_VERSION_SHORT "${ROCM_VERSION}") + string(LENGTH "${ROCM_VERSION}" _ROCM_VERSION_LEN) + if(_ROCM_VERSION_LEN GREATER 1) + math(EXPR _ROCM_MAJOR_LEN "${_ROCM_VERSION_LEN} - 1") + string(SUBSTRING "${ROCM_VERSION}" 0 ${_ROCM_MAJOR_LEN} _ROCM_MAJOR) + string(SUBSTRING "${ROCM_VERSION}" ${_ROCM_MAJOR_LEN} 1 _ROCM_MINOR) + set(_DETECTED_ROCM_VERSION "${_ROCM_MAJOR}.${_ROCM_MINOR}") + else() + message(WARNING + "ROCM_VERSION='${ROCM_VERSION}' looks like a single digit. " + "Expected a two-digit shortcode (e.g. 71 for ROCm 7.1). " + "Interpreting as ${ROCM_VERSION}.0." + ) + set(_DETECTED_ROCM_VERSION "${ROCM_VERSION}.0") + endif() + else() + # Detect the actual ROCm version. + # Prefer the .info/version file (the canonical ROCm version) because the + # HIP SDK version diverged from the ROCm version starting with ROCm 7.x + set(_DETECTED_ROCM_VERSION "") + + # Resolve the ROCm installation root (same logic used later for find_package) + if(DEFINED ENV{ROCM_PATH}) + set(_ROCM_ROOT "$ENV{ROCM_PATH}") + else() + if(WIN32) + message(WARNING + "ROCM_PATH environment variable is not set. " + "On Windows this is the primary way to locate the ROCm installation.\n" + "Falling back to C:/opt/rocm. Set ROCM_PATH if ROCm is installed elsewhere." + ) + set(_ROCM_ROOT "C:/opt/rocm") + else() + set(_ROCM_ROOT "/opt/rocm") + endif() + endif() - # Expose a cache variable that the user can set to override the ROCm version in the library name - set(ROCM_VERSION "${HIP_VERSION_SHORT}" CACHE STRING "Expected ROCm Version Shortcode") + # Try /.info/version + if(_ROCM_ROOT AND EXISTS "${_ROCM_ROOT}/.info/version") + file(READ "${_ROCM_ROOT}/.info/version" _ROCM_INFO_CONTENT) + string(STRIP "${_ROCM_INFO_CONTENT}" _ROCM_INFO_CONTENT) + string(REGEX MATCH "[0-9]+\\.[0-9]+" _DETECTED_ROCM_VERSION "${_ROCM_INFO_CONTENT}") + if(_DETECTED_ROCM_VERSION) + message(STATUS "ROCm Version: ${_DETECTED_ROCM_VERSION} (from ${_ROCM_ROOT}/.info/version)") + endif() + endif() + + # Fall back to hipconfig --version (HIP SDK version) for older installs + if(NOT _DETECTED_ROCM_VERSION) + execute_process(COMMAND hipconfig --version OUTPUT_VARIABLE HIP_CONFIG_VERSION) + string(REGEX MATCH "[0-9]+\\.[0-9]+" _DETECTED_ROCM_VERSION "${HIP_CONFIG_VERSION}") + + if(_DETECTED_ROCM_VERSION) + message(WARNING + "Could not read ROCm version from ${_ROCM_ROOT}/.info/version; " + "falling back to hipconfig (${_DETECTED_ROCM_VERSION}).\n" + "Starting with ROCm 7.x the HIP SDK version diverges from the ROCm version, " + "which may produce a misnamed library.\n" + "To fix this you can either:\n" + " - At build time: cmake -DROCM_VERSION= (e.g. -DROCM_VERSION=71 for ROCm 7.1)\n" + " - At runtime: export BNB_ROCM_VERSION= (e.g. BNB_ROCM_VERSION=71)" + ) + else() + message(FATAL_ERROR + "Could not detect the ROCm version.\n" + "Checked:\n" + " 1. ${_ROCM_ROOT}/.info/version — not found or not readable\n" + " 2. hipconfig --version — not found or returned no version\n" + "Please install ROCm and ensure ROCM_PATH is set correctly, or\n" + ) + endif() + endif() - message(STATUS "ROCm Version: ${HIP_VERSION_SHORT} (from hipconfig)") - if(NOT ROCM_VERSION STREQUAL "${HIP_VERSION_SHORT}") - message(WARNING "Overriding ROCm version in library name: ${HIP_VERSION_SHORT} -> ${ROCM_VERSION}") + string(REPLACE "." "" _DETECTED_ROCM_VERSION_SHORT "${_DETECTED_ROCM_VERSION}") + set(_ROCM_VERSION_SHORT "${_DETECTED_ROCM_VERSION_SHORT}") + set(ROCM_VERSION "${_DETECTED_ROCM_VERSION_SHORT}" CACHE STRING "Expected ROCm Version Shortcode") endif() - string(APPEND BNB_OUTPUT_NAME "${ROCM_VERSION}") + string(APPEND BNB_OUTPUT_NAME "${_ROCM_VERSION_SHORT}") add_compile_definitions(__HIP_PLATFORM_AMD__) add_compile_definitions(__HIP_PLATFORM_HCC__) - add_compile_definitions(BUILD_HIP) + add_compile_definitions(BUILD_ROCM) elseif(BUILD_MPS) if(NOT APPLE) message(FATAL_ERROR "MPS is only supported on macOS" ) @@ -354,7 +427,7 @@ if(BUILD_CUDA) CUDA_SEPARABLE_COMPILATION ON ) endif() -if(BUILD_HIP) +if(BUILD_ROCM) # Determine ROCM_PATH from environment variable, fallback to /opt/rocm on Linux if(DEFINED ENV{ROCM_PATH}) set(ROCM_PATH $ENV{ROCM_PATH}) @@ -391,7 +464,7 @@ if(BUILD_HIP) set_source_files_properties(${GPU_FILES} PROPERTIES LANGUAGE HIP) set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX) - if(HIP_VERSION VERSION_LESS "6.1") + if(_DETECTED_ROCM_VERSION VERSION_LESS "6.1") target_compile_definitions(bitsandbytes PUBLIC NO_HIPBLASLT) else() find_package(hipblaslt) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index c3cec7281..4d9115624 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -31,28 +31,34 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: prefix = "rocm" if torch.version.hip else "cuda" library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}" - override_value = os.environ.get("BNB_CUDA_VERSION") + cuda_override_value = os.environ.get("BNB_CUDA_VERSION") rocm_override_value = os.environ.get("BNB_ROCM_VERSION") - if rocm_override_value and torch.version.hip: + if rocm_override_value: library_name = re.sub(r"rocm\d+", f"rocm{rocm_override_value}", library_name, count=1) + if torch.version.cuda: + raise RuntimeError( + f"BNB_ROCM_VERSION={rocm_override_value} detected for CUDA!\n" + "Use BNB_CUDA_VERSION instead: export BNB_CUDA_VERSION=\n" + "Clear the variable and retry: unset BNB_ROCM_VERSION\n" + ) logger.warning( f"WARNING: BNB_ROCM_VERSION={rocm_override_value} environment variable detected; loading {library_name}.\n" "This can be used to load a bitsandbytes version built with a ROCm version that is different from the PyTorch ROCm version.\n" - "If this was unintended set the BNB_ROCM_VERSION variable to an empty string: export BNB_ROCM_VERSION=\n" + "If this was unintended clear the variable and retry: unset BNB_ROCM_VERSION\n" ) - elif override_value: - library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1) + elif cuda_override_value: + library_name = re.sub(r"cuda\d+", f"cuda{cuda_override_value}", library_name, count=1) if torch.version.hip: raise RuntimeError( - f"BNB_CUDA_VERSION={override_value} detected for ROCm!! \n" + f"BNB_CUDA_VERSION={cuda_override_value} detected for ROCm!\n" f"Use BNB_ROCM_VERSION instead: export BNB_ROCM_VERSION=\n" - f"Clear the variable and retry: export BNB_CUDA_VERSION=\n" + f"Clear the variable and retry: unset BNB_CUDA_VERSION\n" ) logger.warning( - f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n" + f"WARNING: BNB_CUDA_VERSION={cuda_override_value} environment variable detected; loading {library_name}.\n" "This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n" - "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n" + "If this was unintended clear the variable and retry: unset BNB_CUDA_VERSION\n" ) return PACKAGE_DIR / library_name diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 2a1d06daa..a4505b4e3 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -26,12 +26,18 @@ def get_compute_capabilities() -> list[tuple[int, int]]: @lru_cache(None) def get_cuda_version_tuple() -> Optional[tuple[int, int]]: - """Get CUDA/HIP version as a tuple of (major, minor).""" + """Get CUDA/ROCm version as a tuple of (major, minor). + + For ROCm, prefers ``torch.version.rocm`` (the actual ROCm version) + over ``torch.version.hip`` (the HIP SDK version) because the two + version lines diverged starting with ROCm 7.x. Falls back to + ``torch.version.hip`` when the attribute is not yet available. + """ try: if torch.version.cuda: version_str = torch.version.cuda elif torch.version.hip: - version_str = torch.version.hip + version_str = getattr(torch.version, "rocm", None) or torch.version.hip else: return None @@ -44,7 +50,7 @@ def get_cuda_version_tuple() -> Optional[tuple[int, int]]: def get_cuda_version_string() -> Optional[str]: - """Get CUDA/HIP version as a string.""" + """Get CUDA/ROCm version as a compact string (e.g. ``"120"`` or ``"71"``).""" version_tuple = get_cuda_version_tuple() if version_tuple is None: return None diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py index de4d036cb..87ce212d8 100644 --- a/bitsandbytes/diagnostics/cuda.py +++ b/bitsandbytes/diagnostics/cuda.py @@ -204,7 +204,7 @@ def _print_hip_runtime_diagnostics() -> None: f""" Found duplicate ROCm runtime files (see below). - We select the PyTorch default ROCm runtime, which is {torch.version.hip}, + We select the PyTorch default ROCm runtime, which is {getattr(torch.version, "rocm", None) or torch.version.hip}, but this might mismatch with the ROCm version that is needed for bitsandbytes. To override this behavior set the `BNB_ROCM_VERSION=` environmental variable. diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index 74da662b6..cac3e05ea 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -61,6 +61,7 @@ def show_environment(): print(f"PyTorch: {torch.__version__}") print(f" CUDA: {torch.version.cuda or 'N/A'}") + print(f" ROCm: {getattr(torch.version, 'rocm', 'N/A') or 'N/A'}") print(f" HIP: {torch.version.hip or 'N/A'}") print(f" XPU: {getattr(torch.version, 'xpu', 'N/A') or 'N/A'}") diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 7493574f0..236fe89ff 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -7,7 +7,8 @@ #include #include #endif -#if BUILD_HIP +#if BUILD_ROCM +#include #include #endif #if BUILD_MPS @@ -19,7 +20,7 @@ #include // Compatibility between HIP/CUDA APIs -#if BUILD_HIP +#if BUILD_ROCM #define cudaStream_t hipStream_t #define __nv_bfloat16 hip_bfloat16 #define cublasLtHandle_t hipblasLtHandle_t @@ -38,7 +39,7 @@ // UNMANGLED CALLS //=================================================================================== -#if BUILD_CUDA || BUILD_HIP +#if BUILD_CUDA || BUILD_ROCM void gemm_4bit_inference_naive_fp16( int m, int n, int k, half* A, unsigned char* B, float* absmax, float* datatype, half* out, int lda, int ldb, @@ -334,7 +335,7 @@ void gemv_4bit_inference_fp32( #endif extern "C" { -#if BUILD_CUDA || BUILD_HIP +#if BUILD_CUDA || BUILD_ROCM void cdequantize_blockwise_fp16_fp4( float* code, unsigned char* A, float* absmax, half* out, int blocksize, const int n, cudaStream_t stream diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 3d170436d..0d8e64479 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -225,7 +225,7 @@ git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bits # Compile & install apt-get install -y build-essential cmake # install build tools dependencies, unless present -cmake -DCOMPUTE_BACKEND=hip -S . # Use -DBNB_ROCM_ARCH="gfx90a;gfx942" to target specific gpu arch +cmake -DCOMPUTE_BACKEND=rocm -S . # Use -DBNB_ROCM_ARCH="gfx90a;gfx942" to target specific gpu arch make pip install -e . # `-e` for "editable" install, when developing BNB (otherwise leave that out) ``` diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index f74f05634..604b20d63 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,11 +1,12 @@ import pytest -from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path +from bitsandbytes.cextension import BNB_BACKEND, get_cuda_bnb_library_path from bitsandbytes.cuda_specs import CUDASpecs @pytest.fixture def cuda120_spec() -> CUDASpecs: + """Simulates torch+cuda12.0 and a representative Ampere-class capability.""" return CUDASpecs( cuda_version_string="120", highest_compute_capability=(8, 6), @@ -13,30 +14,42 @@ def cuda120_spec() -> CUDASpecs: ) -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") +@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend") def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec): + """Without overrides, library path uses the detected CUDA 12.0 version.""" + monkeypatch.delenv("BNB_ROCM_VERSION", raising=False) monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120" -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") +@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend") def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): + """BNB_CUDA_VERSION=110 overrides path selection to the CUDA 11.0 binary.""" monkeypatch.setenv("BNB_CUDA_VERSION", "110") assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110" assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning? -# Simulates torch+rocm7.0 (PyTorch bundled ROCm) on a system with ROCm 7.2 +@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend") +def test_get_cuda_bnb_library_path_rejects_rocm_override(monkeypatch, cuda120_spec): + """BNB_ROCM_VERSION should be rejected on CUDA with a helpful error.""" + monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) + monkeypatch.setenv("BNB_ROCM_VERSION", "72") + with pytest.raises(RuntimeError, match=r"BNB_ROCM_VERSION.*detected for CUDA"): + get_cuda_bnb_library_path(cuda120_spec) + + @pytest.fixture def rocm70_spec() -> CUDASpecs: + """Simulates torch+rocm7.0 (bundled ROCm) when the system ROCm is newer.""" return CUDASpecs( - cuda_version_string="70", # from torch.version.hip == "7.0.x" - highest_compute_capability=(0, 0), # unused for ROCm library path resolution + cuda_version_string="70", + highest_compute_capability=(0, 0), cuda_version_tuple=(7, 0), ) -@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm") +@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend") def test_get_rocm_bnb_library_path(monkeypatch, rocm70_spec): """Without override, library path uses PyTorch's ROCm 7.0 version.""" monkeypatch.delenv("BNB_ROCM_VERSION", raising=False) @@ -44,7 +57,7 @@ def test_get_rocm_bnb_library_path(monkeypatch, rocm70_spec): assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm70" -@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm") +@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend") def test_get_rocm_bnb_library_path_override(monkeypatch, rocm70_spec, caplog): """BNB_ROCM_VERSION=72 overrides to load the ROCm 7.2 library instead of 7.0.""" monkeypatch.setenv("BNB_ROCM_VERSION", "72") @@ -53,20 +66,10 @@ def test_get_rocm_bnb_library_path_override(monkeypatch, rocm70_spec, caplog): assert "BNB_ROCM_VERSION" in caplog.text -@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm") +@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend") def test_get_rocm_bnb_library_path_rejects_cuda_override(monkeypatch, rocm70_spec): """BNB_CUDA_VERSION should be rejected on ROCm with a helpful error.""" monkeypatch.delenv("BNB_ROCM_VERSION", raising=False) - monkeypatch.setenv("BNB_CUDA_VERSION", "72") + monkeypatch.setenv("BNB_CUDA_VERSION", "120") with pytest.raises(RuntimeError, match=r"BNB_CUDA_VERSION.*detected for ROCm"): get_cuda_bnb_library_path(rocm70_spec) - - -@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm") -def test_get_rocm_bnb_library_path_rocm_override_takes_priority(monkeypatch, rocm70_spec, caplog): - """When both are set, BNB_ROCM_VERSION wins if HIP_ENVIRONMENT is True.""" - monkeypatch.setenv("BNB_ROCM_VERSION", "72") - monkeypatch.setenv("BNB_CUDA_VERSION", "72") - assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm72" - assert "BNB_ROCM_VERSION" in caplog.text - assert "BNB_CUDA_VERSION" not in caplog.text