From 15bab1ba2f791907e6adf831d080f2c62c1809e8 Mon Sep 17 00:00:00 2001 From: sssshhhhhh <193317444+sssshhhhhh@users.noreply.github.com> Date: Fri, 23 Jan 2026 17:37:47 +1100 Subject: [PATCH 1/6] AMD GPU support with ROCm HIP --- CMakeLists.txt | 92 ++++++++++++++++++++++++++++++++++ include/ctranslate2/ops/ops.h | 2 + python/ctranslate2/__init__.py | 1 + src/cuda/allocator.cc | 22 ++++++-- src/cuda/helpers.h | 14 ++++-- src/cuda/primitives.cu | 40 ++++++++++++--- src/cuda/random.h | 7 +++ src/cuda/utils.cc | 43 ++++++++++++++++ src/cuda/utils.h | 31 ++++++++++++ src/layers/common.cc | 5 ++ src/models/model.cc | 8 +++ src/ops/layer_norm_gpu.cu | 6 +++ src/ops/mean_gpu.cu | 6 +++ src/ops/multinomial_gpu.cu | 7 +++ src/ops/rms_norm_gpu.cu | 6 +++ src/ops/topk_gpu.cu | 20 ++++++++ src/ops/topp_mask_gpu.cu | 5 ++ tests/CMakeLists.txt | 26 ++++++++-- tests/benchmark_utils.h | 5 ++ 19 files changed, 330 insertions(+), 16 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 680677f40..083be912c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,6 +14,7 @@ option(WITH_OPENBLAS "Compile with OpenBLAS backend" OFF) option(WITH_RUY "Compile with Ruy backend" OFF) option(WITH_CUDA "Compile with CUDA backend" OFF) option(WITH_CUDNN "Compile with cuDNN backend" OFF) +option(WITH_HIP "Compile with HIP backend" OFF) option(CUDA_DYNAMIC_LOADING "Dynamically load CUDA libraries at runtime" OFF) option(ENABLE_CPU_DISPATCH "Compile CPU kernels for multiple ISA and dispatch at runtime" ON) option(ENABLE_PROFILING "Compile with profiling support" OFF) @@ -491,6 +492,9 @@ ELSEIF (ENABLE_ADDRESS_SANITIZER) ENDIF () if (WITH_CUDA) + if(WITH_HIP) + message(FATAL_ERROR "WITH_CUDA=ON incompatible with WITH_HIP=ON") + endif() find_package(CUDA 11.0 REQUIRED) list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake) if (WITH_TENSOR_PARALLEL) @@ -679,6 +683,94 @@ if (WITH_CUDA) ) +elseif(WITH_HIP) + if(WITH_TENSOR_PARALLEL) + message(FATAL_ERROR "WITH_HIP=ON incompatible with 
WITH_TENSOR_PARALLEL=ON") + endif() + enable_language(HIP) + set(CMAKE_CXX_STANDARD_REQUIRED ON) + message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}") + message(STATUS "CMAKE_HIP_ARCHITECTURES: ${CMAKE_HIP_ARCHITECTURES}") + + if(NOT DEFINED ENV{ROCM_PATH}) + set(ROCM_PATH /opt/rocm) + else() + set(ROCM_PATH $ENV{ROCM_PATH}) + endif() + list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) + + find_package(hiprand REQUIRED) + find_package(hipblas REQUIRED) + find_package(rocprim REQUIRED) + find_package(rocthrust REQUIRED) + find_package(hipcub REQUIRED) + + list(REMOVE_ITEM SOURCES + src/ops/awq/dequantize.cc + src/ops/awq/dequantize_cpu.cc + src/ops/awq/gemm.cc + src/ops/awq/gemm_cpu.cc + src/ops/awq/gemv.cc + src/ops/awq/gemv_cpu.cc + ) + list(REMOVE_ITEM CUDA_SOURCES + src/ops/awq/gemm_gpu.cu + src/ops/awq/gemv_gpu.cu + src/ops/awq/dequantize_gpu.cu + ) + if(WITH_FLASH_ATTN) + message(FATAL_ERROR "WITH_HIP=ON incompatible with WITH_FLASH_ATTN=ON") + endif() + + set_source_files_properties(${CUDA_SOURCES} PROPERTIES LANGUAGE HIP) + set_source_files_properties( + src/cpu/allocator.cc + src/cpu/backend.cc + src/cpu/cpu_info.cc + src/cpu/cpu_isa.cc + src/cpu/kernels.cc + src/cpu/parallel.cc + src/cpu/primitives.cc + src/ops/alibi_add_cpu.cc + src/ops/bias_add_cpu.cc + src/ops/concat_split_slide_cpu.cc + src/ops/conv1d_cpu.cc + src/ops/dequantize_cpu.cc + src/ops/gather_cpu.cc + src/ops/gumbel_max_cpu.cc + src/ops/layer_norm_cpu.cc + src/ops/mean_cpu.cc + src/ops/median_filter_cpu.cc + src/ops/multinomial_cpu.cc + src/ops/quantize_cpu.cc + src/ops/rms_norm_cpu.cc + src/ops/rotary_cpu.cc + src/ops/softmax_cpu.cc + src/ops/tile_cpu.cc + src/ops/topk_cpu.cc + src/ops/topp_mask_cpu.cc + src/ops/nccl_ops_cpu.cc + PROPERTIES LANGUAGE CXX + ) + link_directories(${ROCM_PATH}/lib) + + add_definitions(-DCT2_WITH_CUDA) + add_definitions(-DCT2_USE_HIP) + + add_library(${PROJECT_NAME} + SHARED + ${SOURCES} + ${CUDA_SOURCES} + ) + + add_compile_definitions(__HIP_PLATFORM_AMD__) + 
add_compile_definitions(__HIP_PLATFORM_HCC__) + target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include) + target_link_libraries(${PROJECT_NAME} PRIVATE hiprand roc::hipblas roc::rocprim roc::rocthrust hip::hipcub) + set_target_properties(${PROJECT_NAME} PROPERTIES LINKER_LANGUAGE CXX) + + elseif(WITH_CUDNN) message(FATAL_ERROR "WITH_CUDNN=ON requires WITH_CUDA=ON") else() diff --git a/include/ctranslate2/ops/ops.h b/include/ctranslate2/ops/ops.h index 2a735e394..f48e56014 100644 --- a/include/ctranslate2/ops/ops.h +++ b/include/ctranslate2/ops/ops.h @@ -40,7 +40,9 @@ #include "slide.h" #include "nccl_ops.h" #include "flash_attention.h" +#ifndef CT2_USE_HIP #include "awq/gemm.h" #include "awq/gemv.h" #include "awq/dequantize_awq.h" +#endif #include "sum.h" diff --git a/python/ctranslate2/__init__.py b/python/ctranslate2/__init__.py index b9cc58376..f34d68b7f 100644 --- a/python/ctranslate2/__init__.py +++ b/python/ctranslate2/__init__.py @@ -21,6 +21,7 @@ add_dll_directory = getattr(os, "add_dll_directory", None) if add_dll_directory is not None: add_dll_directory(package_dir) + add_dll_directory(f"{package_dir}/../_rocm_sdk_libraries_custom/bin") for library in glob.glob(os.path.join(package_dir, "*.dll")): ctypes.CDLL(library) diff --git a/src/cuda/allocator.cc b/src/cuda/allocator.cc index 2311bd008..921ed34de 100644 --- a/src/cuda/allocator.cc +++ b/src/cuda/allocator.cc @@ -7,8 +7,24 @@ #include "cuda/utils.h" #include "env.h" +#ifdef CT2_USE_HIP +#include +#include +#define cub hipcub +#define cudaGetDevice hipGetDevice +#define cudaSetDevice hipSetDevice +#define cudaFreeAsync hipFreeAsync +#define cudaMallocAsync hipMallocAsync +#define cudaDeviceGetAttribute hipDeviceGetAttribute +#define cudaDevAttrMemoryPoolsSupported hipDeviceAttributeMemoryPoolsSupported +// Async allocator has crashing issues on Windows +// 
https://github.com/OpenNMT/CTranslate2/issues/1072#issuecomment-3418768140 +#define CT2_USE_ASYNC_ALLOC !_WIN32 +#else #include #include +#define CT2_USE_ASYNC_ALLOC CUDA_VERSION >= 11020 +#endif #include namespace ctranslate2 { @@ -63,7 +79,7 @@ namespace ctranslate2 { class CudaAsyncAllocator : public Allocator { public: void* allocate(size_t size, int device_index) override { -#if CUDA_VERSION >= 11020 +#if CT2_USE_ASYNC_ALLOC int prev_device_index = -1; if (device_index >= 0) { CUDA_CHECK(cudaGetDevice(&prev_device_index)); @@ -86,7 +102,7 @@ namespace ctranslate2 { } void free(void* ptr, int device_index) override { -#if CUDA_VERSION >= 11020 +#if CT2_USE_ASYNC_ALLOC int prev_device_index = -1; if (device_index >= 0) { CUDA_CHECK(cudaGetDevice(&prev_device_index)); @@ -107,7 +123,7 @@ namespace ctranslate2 { }; static bool support_cuda_malloc_async() { -#if CUDA_VERSION < 11020 +#if !CT2_USE_ASYNC_ALLOC return false; #else for (int i = 0; i < get_gpu_count(); ++i) { diff --git a/src/cuda/helpers.h b/src/cuda/helpers.h index cf2fab812..3554296c1 100644 --- a/src/cuda/helpers.h +++ b/src/cuda/helpers.h @@ -3,20 +3,28 @@ #include #include +#ifdef CT2_USE_HIP +#include +#include +#include +#include +#define __nv_bfloat16 __hip_bfloat16 +#else #include #include +#endif #include "ctranslate2/types.h" #include "utils.h" -#if !defined(__CUDACC__) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 +#if !defined(__CUDACC__) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 || defined(CT2_USE_HIP) # define CUDA_CAN_USE_HALF 1 #else # define CUDA_CAN_USE_HALF 0 #endif -#if defined(__CUDACC__) && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) +#if defined(__CUDACC__) && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) || defined(CT2_USE_HIP) # define CUDA_CAN_USE_BF16_MATH 1 #else # define CUDA_CAN_USE_BF16_MATH 0 @@ -416,7 +424,7 @@ namespace ctranslate2 { AccumT warpVal = defaultVal; // First warp will perform per-warp reductions for the remaining warps - uint32_t 
mask = (((uint64_t)1) << (blockDim.x / C10_WARP_SIZE)) - 1; + uint64_t mask = (((uint64_t)1) << (blockDim.x / C10_WARP_SIZE)) - 1; if (threadIdx.x < C10_WARP_SIZE) { index_t lane = threadIdx.x % C10_WARP_SIZE; if (lane < blockDim.x / C10_WARP_SIZE) { diff --git a/src/cuda/primitives.cu b/src/cuda/primitives.cu index 70bcaeaeb..2f21585b1 100644 --- a/src/cuda/primitives.cu +++ b/src/cuda/primitives.cu @@ -1,7 +1,33 @@ #include "ctranslate2/primitives.h" +#ifdef CT2_USE_HIP +#include +#include +#include +#define cudaMemcpyAsync hipMemcpyAsync +#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cublasSgemm hipblasSgemm +#define CUBLAS_OP_T HIPBLAS_OP_T +#define CUBLAS_OP_N HIPBLAS_OP_N +#define cublasComputeType_t hipblasComputeType_t +#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F +#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F +#define CUBLAS_COMPUTE_32I HIPBLAS_COMPUTE_32I +#define CUDA_R_16F HIP_R_16F +#define CUDA_R_16BF HIP_R_16BF +#define CUDA_R_32F HIP_R_32F +#define CUDA_R_8I HIP_R_8I +#define CUDA_R_32I HIP_R_32I +#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT +#define cublasSgemmStridedBatched hipblasSgemmStridedBatched +#define cublasGemmEx hipblasGemmEx +#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx +#else #include #include +#endif #include #include "cuda/helpers.h" @@ -493,12 +519,12 @@ namespace ctranslate2 { const void* alpha_ptr = &alpha_h; const void* beta_ptr = &beta_h; - cudaDataType_t compute_type = CUDA_R_16F; + cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; if (!cuda::use_true_fp16_gemm()) { alpha_ptr = α beta_ptr = β - compute_type = CUDA_R_32F; + compute_type = CUBLAS_COMPUTE_32F; } // cuBLAS assumes column-major storage, so swap a and b accordingly. 
@@ -536,7 +562,7 @@ namespace ctranslate2 { a, CUDA_R_16BF, lda, &beta, c, CUDA_R_16BF, ldc, - CUDA_R_32F, + CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT)); } @@ -564,7 +590,7 @@ namespace ctranslate2 { a, CUDA_R_8I, lda, &beta_i, c, CUDA_R_32I, ldc, - CUDA_R_32I, + CUBLAS_COMPUTE_32I, CUBLAS_GEMM_DEFAULT)); } @@ -606,12 +632,12 @@ namespace ctranslate2 { const void* alpha_ptr = &alpha_h; const void* beta_ptr = &beta_h; - cudaDataType_t compute_type = CUDA_R_16F; + cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; if (!cuda::use_true_fp16_gemm()) { alpha_ptr = α beta_ptr = β - compute_type = CUDA_R_32F; + compute_type = CUBLAS_COMPUTE_32F; } // cuBLAS assumes column-major storage, so swap a and b accordingly. @@ -650,7 +676,7 @@ namespace ctranslate2 { &beta, c, CUDA_R_16BF, ldc, stridec, batch_size, - CUDA_R_32F, + CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT)); } diff --git a/src/cuda/random.h b/src/cuda/random.h index e12ae20f6..11ceba3be 100644 --- a/src/cuda/random.h +++ b/src/cuda/random.h @@ -1,6 +1,13 @@ #pragma once +#ifdef CT2_USE_HIP +#include +#define curandStatePhilox4_32_10_t hiprandStatePhilox4_32_10_t +#define curand_init hiprand_init +#define curand_uniform hiprand_uniform +#else #include +#endif namespace ctranslate2 { namespace cuda { diff --git a/src/cuda/utils.cc b/src/cuda/utils.cc index 749c4e30c..979165719 100644 --- a/src/cuda/utils.cc +++ b/src/cuda/utils.cc @@ -10,6 +10,31 @@ #include "env.h" +#ifdef CT2_USE_HIP +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED +#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED +#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE +#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH +#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR +#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED +#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR +#define 
CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED +#define CUBLAS_STATUS_LICENSE_ERROR HIPBLAS_STATUS_UNKNOWN +#define cudaStreamDefault hipStreamDefault +#define cudaGetDevice hipGetDevice +#define cudaStreamCreate hipStreamCreate +#define cudaStreamDestroy hipStreamDestroy +#define cublasCreate hipblasCreate +#define cublasDestroy hipblasDestroy +#define cublasSetStream hipblasSetStream +#define cudaMalloc hipMalloc +#define cudaFree hipFree +#define cudaGetDeviceCount hipGetDeviceCount +#define cudaSuccess hipSuccess +#define cudaGetDeviceProperties hipGetDeviceProperties +#endif + namespace ctranslate2 { namespace cuda { @@ -177,6 +202,23 @@ namespace ctranslate2 { return *device_prop; } +#ifdef CT2_USE_HIP + // https://rocm.docs.amd.com/en/latest/reference/precision-support.html + // All archs supported by ROCm 7 support the following precisions + + bool gpu_supports_int8(int device) { + return true; + } + + bool gpu_has_int8_tensor_cores(int device) { + return true; + } + + bool gpu_has_fp16_tensor_cores(int device) { + return true; + } +#else + // See docs.nvidia.com/deeplearning/sdk/tensorrt-support-matrix/index.html // for hardware support of reduced precision. 
@@ -194,6 +236,7 @@ namespace ctranslate2 { const cudaDeviceProp& device_prop = get_device_properties(device); return device_prop.major >= 7; } +#endif bool have_same_compute_capability(const std::vector& devices) { if (devices.size() > 1) { diff --git a/src/cuda/utils.h b/src/cuda/utils.h index 2f8c4f5ab..7aa7bb893 100644 --- a/src/cuda/utils.h +++ b/src/cuda/utils.h @@ -2,6 +2,32 @@ #include +#ifdef CT2_USE_HIP +#include +#include +#include +#include +#ifdef CT2_WITH_TENSOR_PARALLEL + #include + #include +#endif + +#define cub hipcub +#define cudaError_t hipError_t +#define cudaSuccess hipSuccess +#define cudaGetErrorString hipGetErrorString +#define cublasStatus_t hipblasStatus_t +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define cudaStream_t hipStream_t +#define cublasHandle_t hipblasHandle_t +#define cudaDeviceProp hipDeviceProp_t + +#define cudaGetDevice hipGetDevice +#define cudaSetDevice hipSetDevice +#define cudaDeviceSynchronize hipDeviceSynchronize +#define cudaStreamSynchronize hipStreamSynchronize + +#else #include #include #include @@ -13,6 +39,7 @@ #ifdef CT2_WITH_CUDNN # include #endif +#endif #include "ctranslate2/types.h" #include "ctranslate2/utils.h" @@ -107,7 +134,11 @@ namespace ctranslate2 { }; // Convenience macro to call Thrust functions with a default execution policy. +#ifdef CT2_USE_HIP +#define THRUST_CALL(FUN, ...) FUN(thrust::hip::par_nosync.on(ctranslate2::cuda::get_cuda_stream()), __VA_ARGS__) +#else #define THRUST_CALL(FUN, ...) 
FUN(thrust::cuda::par_nosync.on(ctranslate2::cuda::get_cuda_stream()), __VA_ARGS__) +#endif } } diff --git a/src/layers/common.cc b/src/layers/common.cc index 0e9c220cb..bd7a0965e 100644 --- a/src/layers/common.cc +++ b/src/layers/common.cc @@ -400,6 +400,10 @@ namespace ctranslate2 { if (residual) ops::Add()(*residual, output, output); } else if (_qzero && _qscale) { +#ifdef CT2_USE_HIP + (void)_activation_type; + throw std::invalid_argument("AWQ unsupported with ROCm"); +#else switch (_quant_method) { case models::QUANTIZATION_TYPE::AWQ_GEMM: if (input.dim(0) * input.dim(1) >= 1024) { @@ -431,6 +435,7 @@ namespace ctranslate2 { throw std::invalid_argument("Dense forward: invalid quantized type," "support only ct2 and awq quantization"); } +#endif } else { _gemm_op(input, *weight, output, nullptr, bias, residual); } diff --git a/src/models/model.cc b/src/models/model.cc index 5ee37d627..1d8295b0a 100644 --- a/src/models/model.cc +++ b/src/models/model.cc @@ -839,10 +839,18 @@ namespace ctranslate2 { if (device == Device::CUDA) { int device_id = ctranslate2::get_device_index(ctranslate2::Device::CUDA); auto dprops = ctranslate2::cuda::get_device_properties(device_id); +#ifdef CT2_USE_HIP + supports_flash_attention = false; +#else supports_flash_attention = dprops.major >= 8; +#endif } if (use_flash_attention && !supports_flash_attention) { +#ifdef CT2_USE_HIP + throw std::invalid_argument("FlashAttention not supported on ROCm."); +#else throw std::invalid_argument("FlashAttention only supports Ampere GPUs or newer."); +#endif } #endif diff --git a/src/ops/layer_norm_gpu.cu b/src/ops/layer_norm_gpu.cu index 4c3bb149f..b04f00999 100644 --- a/src/ops/layer_norm_gpu.cu +++ b/src/ops/layer_norm_gpu.cu @@ -159,7 +159,13 @@ namespace ctranslate2 { POSSIBILITY OF SUCH DAMAGE. 
*/ +#ifdef CT2_USE_HIP +#include +#include +#define cub hipcub +#else #include +#endif namespace at { namespace native { diff --git a/src/ops/mean_gpu.cu b/src/ops/mean_gpu.cu index a57a679a6..ebea82bde 100644 --- a/src/ops/mean_gpu.cu +++ b/src/ops/mean_gpu.cu @@ -1,6 +1,12 @@ #include "ctranslate2/ops/mean.h" +#ifdef CT2_USE_HIP +#include +#include +#define cub hipcub +#else #include +#endif #include "type_dispatch.h" #include "cuda/helpers.h" diff --git a/src/ops/multinomial_gpu.cu b/src/ops/multinomial_gpu.cu index 90f36377a..cd84051dc 100644 --- a/src/ops/multinomial_gpu.cu +++ b/src/ops/multinomial_gpu.cu @@ -1,7 +1,14 @@ #include "ctranslate2/ops/multinomial.h" +#ifdef CT2_USE_HIP +#include +#include +#include +#define cub hipcub +#else #include #include +#endif #include "cuda/helpers.h" #include "cuda/random.h" diff --git a/src/ops/rms_norm_gpu.cu b/src/ops/rms_norm_gpu.cu index 086e6323e..9cb3dd667 100644 --- a/src/ops/rms_norm_gpu.cu +++ b/src/ops/rms_norm_gpu.cu @@ -1,6 +1,12 @@ #include "ctranslate2/ops/rms_norm.h" +#ifdef CT2_USE_HIP +#include +#include +#define cub hipcub +#else #include +#endif #include "cuda/helpers.h" #include "cuda/utils.h" diff --git a/src/ops/topk_gpu.cu b/src/ops/topk_gpu.cu index ad010fb47..3262d0d07 100644 --- a/src/ops/topk_gpu.cu +++ b/src/ops/topk_gpu.cu @@ -116,7 +116,27 @@ namespace ctranslate2 { SOFTWARE. 
*/ +#ifdef CT2_USE_HIP +#include +#define cub hipcub +namespace hipcub { + template <> + struct FpLimits<__hip_bfloat16> // hipcub only defines hip_bfloat16 + { + static HIPCUB_HOST_DEVICE __forceinline__ __hip_bfloat16 Max() { + unsigned short max_word = 0x7F7F; + return reinterpret_cast<__hip_bfloat16 &>(max_word); + } + + static HIPCUB_HOST_DEVICE __forceinline__ __hip_bfloat16 Lowest() { + unsigned short lowest_word = 0xFF7F; + return reinterpret_cast<__hip_bfloat16 &>(lowest_word); + } + }; +} +#else #include +#endif namespace fastertransformer { diff --git a/src/ops/topp_mask_gpu.cu b/src/ops/topp_mask_gpu.cu index a4e6cb6e5..1f136a1da 100644 --- a/src/ops/topp_mask_gpu.cu +++ b/src/ops/topp_mask_gpu.cu @@ -1,6 +1,11 @@ #include "ctranslate2/ops/topp_mask.h" +#ifdef CT2_USE_HIP +#include +#define cub hipcub +#else #include +#endif #include "cuda/helpers.h" diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index dd3d911b8..12306203f 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -19,7 +19,7 @@ add_executable(ctranslate2_test target_include_directories(ctranslate2_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../src ) -target_link_libraries(ctranslate2_test +target_link_libraries(ctranslate2_test PRIVATE ${PROJECT_NAME} gtest_main ) @@ -27,10 +27,30 @@ target_link_libraries(ctranslate2_test add_executable(benchmark_ops benchmark_ops.cc ) -target_link_libraries(benchmark_ops +target_link_libraries(benchmark_ops PRIVATE ${PROJECT_NAME} ) +if(NOT MSVC) + target_compile_options(benchmark_ops PRIVATE -Wno-unused-result) +endif() if(WITH_CUDA) - target_link_libraries(benchmark_ops ${CUDA_LIBRARIES}) + target_link_libraries(benchmark_ops PRIVATE ${CUDA_LIBRARIES}) + +elseif(WITH_HIP) + enable_language(HIP) + list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) + link_directories(${ROCM_PATH}/lib) + + find_package(hiprand REQUIRED) + find_package(hipblas REQUIRED) + find_package(rocprim REQUIRED) + find_package(rocthrust REQUIRED) + find_package(hipcub 
REQUIRED) + + add_compile_definitions(__HIP_PLATFORM_AMD__) + add_compile_definitions(__HIP_PLATFORM_HCC__) + target_include_directories(benchmark_ops PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include) + target_link_libraries(benchmark_ops PRIVATE hiprand roc::hipblas roc::rocprim roc::rocthrust hip::hipcub) + endif() diff --git a/tests/benchmark_utils.h b/tests/benchmark_utils.h index 55ee77b0e..c4769df04 100644 --- a/tests/benchmark_utils.h +++ b/tests/benchmark_utils.h @@ -5,8 +5,13 @@ #include #ifdef CT2_WITH_CUDA +#ifdef CT2_USE_HIP +# include +# define SYNCHRONIZE hipDeviceSynchronize() +#else # include # define SYNCHRONIZE cudaDeviceSynchronize() +#endif #else # define SYNCHRONIZE do {} while (false) #endif From ef6eb3fa7c58834391d40710c50318c86440f2e3 Mon Sep 17 00:00:00 2001 From: sssshhhhhh <193317444+sssshhhhhh@users.noreply.github.com> Date: Wed, 28 Jan 2026 20:19:42 +1100 Subject: [PATCH 2/6] CI docker and windows whls --- .github/workflows/ci.yml | 32 ++++- docker/Dockerfile_rocm | 111 ++++++++++++++++++ docker/build_all.sh | 7 +- python/ctranslate2/__init__.py | 1 + .../prepare_build_environment_windows_rocm.sh | 45 +++++++ 5 files changed, 193 insertions(+), 3 deletions(-) create mode 100644 docker/Dockerfile_rocm create mode 100644 python/tools/prepare_build_environment_windows_rocm.sh diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index be36d5641..7b72df68d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -219,6 +219,30 @@ jobs: name: python-wheels-${{ runner.os }}-${{ matrix.arch }} path: python/wheelhouse + build-python-wheels-rocm: + runs-on: windows-2025 + steps: + - uses: actions/checkout@v6 + with: + submodules: recursive + + - name: Build wheels + uses: pypa/cibuildwheel@v3.2.1 + with: + package-dir: python + output-dir: python/wheelhouse + env: + CIBW_ENVIRONMENT_WINDOWS: CTRANSLATE2_ROOT='${{ github.workspace }}\install' + CIBW_BEFORE_ALL_WINDOWS: bash 
python/tools/prepare_build_environment_windows_rocm.sh + CIBW_BEFORE_BUILD: pip install -r python/install_requirements.txt + CIBW_ARCHS: auto64 + + - name: Upload Python wheels + uses: actions/upload-artifact@v6 + with: + name: rocm-python-wheels-${{ runner.os }} + path: python/wheelhouse + # We could test the Python wheels using cibuildwheel but we prefer to run the tests outside # the build environment to ensure wheels correctly embed all dependencies. @@ -325,6 +349,10 @@ jobs: build-and-push-docker-images: runs-on: ubuntu-22.04 + strategy: + matrix: + gpu: [cuda, rocm] + steps: - uses: actions/checkout@v4 with: @@ -355,7 +383,7 @@ jobs: - name: Build Docker images run: | - ./docker/build_all.sh + ./docker/build_all.sh latest 0 ${{ matrix.gpu }} - name: Login to DockerHub if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') @@ -368,7 +396,7 @@ jobs: - name: Push Docker images if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') run: | - ./docker/build_all.sh ${GITHUB_REF##*/v} 1 + ./docker/build_all.sh ${GITHUB_REF##*/v} 1 ${{ matrix.gpu }} build-and-deploy-docs: diff --git a/docker/Dockerfile_rocm b/docker/Dockerfile_rocm new file mode 100644 index 000000000..8bbf70d7e --- /dev/null +++ b/docker/Dockerfile_rocm @@ -0,0 +1,111 @@ +FROM rocm/dev-ubuntu-22.04:7.2 AS builder + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + rocm-hip-runtime-dev \ + hipblas-common-dev \ + hipblas-dev \ + hipcub-dev \ + hiprand-dev \ + rocprim-dev \ + rocrand-dev \ + rocthrust-dev \ + python3-dev \ + python3-pip \ + wget \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +WORKDIR /root + +ENV ONEAPI_VERSION=2025.3 +RUN wget -q https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB && \ + apt-key add *.PUB && \ + rm *.PUB && \ + echo "deb https://apt.repos.intel.com/oneapi all main" > /etc/apt/sources.list.d/oneAPI.list && \ + apt-get update && \ + apt-get install -y --no-install-recommends 
\ + intel-oneapi-mkl-devel-$ONEAPI_VERSION \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +RUN python3 -m pip --no-cache-dir install cmake==3.22.* + +ENV ONEDNN_VERSION=3.10.2 +RUN wget -q https://github.com/oneapi-src/oneDNN/archive/refs/tags/v${ONEDNN_VERSION}.tar.gz && \ + tar xf *.tar.gz && \ + rm *.tar.gz && \ + cd oneDNN-* && \ + cmake -DCMAKE_BUILD_TYPE=Release -DONEDNN_LIBRARY_TYPE=STATIC -DONEDNN_BUILD_EXAMPLES=OFF -DONEDNN_BUILD_TESTS=OFF -DONEDNN_ENABLE_WORKLOAD=INFERENCE -DONEDNN_ENABLE_PRIMITIVE="CONVOLUTION;REORDER" -DONEDNN_BUILD_GRAPH=OFF . && \ + make -j$(nproc) install && \ + cd .. && \ + rm -r oneDNN-* + +ENV OPENMPI_VERSION=4.1.6 +RUN wget -q https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-${OPENMPI_VERSION}.tar.bz2 && \ + tar xf *.tar.bz2 && \ + rm *.tar.bz2 && \ + cd openmpi-* && \ + ./configure && \ + make -j$(nproc) install && \ + cd .. && \ + rm -r openmpi-* + +COPY third_party third_party +COPY cli cli +COPY include include +COPY src src +COPY cmake cmake +COPY python python +COPY CMakeLists.txt . + +ARG CXX_FLAGS +ENV CXX_FLAGS=${CXX_FLAGS:-"-msse4.1 -O3 -Wno-deprecated-literal-operator"} +ARG HIP_FLAGS +ENV HIP_FLAGS=${HIP_FLAGS:-"-O3 -Wno-deprecated-literal-operator"} +ARG HIP_ARCHITECTURES +ENV HIP_ARCHITECTURES=${HIP_ARCHITECTURES:-"gfx1030;gfx1100;gfx1101;gfx1102;gfx1150;gfx1151;gfx1200;gfx1201"} +ENV CTRANSLATE2_ROOT=/opt/ctranslate2 +ARG LD_LIBRARY_PATH +ENV LD_LIBRARY_PATH=/usr/local/lib/:${LD_LIBRARY_PATH} + +RUN mkdir build_tmp && \ + cd build_tmp && \ + cmake -DCMAKE_INSTALL_PREFIX=${CTRANSLATE2_ROOT} -DCMAKE_C_COMPILER=amdclang -DCMAKE_CXX_COMPILER=amdclang++ \ + -DWITH_HIP=ON -DWITH_MKL=ON -DWITH_DNNL=ON -DOPENMP_RUNTIME=COMP \ + -DCMAKE_HIP_ARCHITECTURES="${HIP_ARCHITECTURES}" \ + -DGPU_TARGETS="${HIP_ARCHITECTURES}" -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CXX_FLAGS="${CXX_FLAGS}" -DCMAKE_HIP_FLAGS="${HIP_FLAGS}" \ + .. 
&& \ + VERBOSE=1 make -j$(nproc) install + +ENV LANG=en_US.UTF-8 +COPY README.md . + +RUN cd python && \ + python3 -m pip --no-cache-dir install -r install_requirements.txt && \ + python3 setup.py bdist_wheel --dist-dir $CTRANSLATE2_ROOT + +FROM rocm/dev-ubuntu-22.04:7.2 + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + rocm-hip-libraries \ + openmpi-bin \ + libgomp1 \ + python3-pip \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +ENV CTRANSLATE2_ROOT=/opt/ctranslate2 +ARG LD_LIBRARY_PATH +ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CTRANSLATE2_ROOT/lib + +COPY --from=builder $CTRANSLATE2_ROOT $CTRANSLATE2_ROOT +RUN python3 -m pip --no-cache-dir install $CTRANSLATE2_ROOT/*.whl && \ + rm $CTRANSLATE2_ROOT/*.whl + +ENTRYPOINT ["/opt/ctranslate2/bin/ct2-translator"] diff --git a/docker/build_all.sh b/docker/build_all.sh index 7d39189f2..f05ea99e8 100755 --- a/docker/build_all.sh +++ b/docker/build_all.sh @@ -20,6 +20,7 @@ cd $ROOT_DIR VERSION=${1:-latest} PUSH=${2:-0} +GPU=${3:-cuda} IMAGE=ghcr.io/opennmt/ctranslate2 build() @@ -42,4 +43,8 @@ build() fi } -build Dockerfile ubuntu22.04-cuda12.2 +if [ "$GPU" == "rocm" ]; then + build Dockerfile_rocm ubuntu22.04-rocm7.2 +else + build Dockerfile ubuntu22.04-cuda12.8 +fi diff --git a/python/ctranslate2/__init__.py b/python/ctranslate2/__init__.py index f34d68b7f..849f1516a 100644 --- a/python/ctranslate2/__init__.py +++ b/python/ctranslate2/__init__.py @@ -21,6 +21,7 @@ add_dll_directory = getattr(os, "add_dll_directory", None) if add_dll_directory is not None: add_dll_directory(package_dir) + add_dll_directory(f"{package_dir}/../_rocm_sdk_core/bin") add_dll_directory(f"{package_dir}/../_rocm_sdk_libraries_custom/bin") for library in glob.glob(os.path.join(package_dir, "*.dll")): diff --git a/python/tools/prepare_build_environment_windows_rocm.sh b/python/tools/prepare_build_environment_windows_rocm.sh new file mode 100644 index 000000000..283468d56 --- /dev/null +++ 
b/python/tools/prepare_build_environment_windows_rocm.sh @@ -0,0 +1,45 @@ +#! /bin/bash + +set -e +set -x + +pip install --no-cache-dir \ + https://repo.radeon.com/rocm/windows/rocm-rel-7.2/rocm_sdk_core-7.2.0.dev0-py3-none-win_amd64.whl \ + https://repo.radeon.com/rocm/windows/rocm-rel-7.2/rocm_sdk_devel-7.2.0.dev0-py3-none-win_amd64.whl \ + https://repo.radeon.com/rocm/windows/rocm-rel-7.2/rocm_sdk_libraries_custom-7.2.0.dev0-py3-none-win_amd64.whl \ + https://repo.radeon.com/rocm/windows/rocm-rel-7.2/rocm-7.2.0.dev0.tar.gz +rocm-sdk init + +export ROCM_PATH=$(python -c "from rocm_sdk._devel import get_devel_root;print(get_devel_root().as_posix())") +export PATH="$ROCM_PATH:$PATH" +export CC="$ROCM_PATH/lib/llvm/bin/clang.exe" +export CXX="$ROCM_PATH/lib/llvm/bin/clang++.exe" + +export HIP_PLATFORM="amd" +export HIP_PATH="$ROCM_PATH" +export HIP_DEVICE_LIB_PATH="$ROCM_PATH/lib/llvm/amdgcn/bitcode" +export HIP_CLANG_ROOT="$ROCM_PATH/lib/llvm" +export PYTORCH_ROCM_ARCH="gfx1030;gfx1100;gfx1101;gfx1102;gfx1150;gfx1151;gfx1200;gfx1201" + +# See https://github.com/oneapi-src/oneapi-ci for installer URLs +curl --netrc-optional -L -nv -o webimage.exe https://registrationcenter-download.intel.com/akdlm/IRC_NAS/1f18901e-877d-469d-a41a-a10f11b39336/intel-oneapi-base-toolkit-2025.3.0.372_offline.exe +./webimage.exe -s -x -f webimage_extracted --log extract.log +rm webimage.exe +./webimage_extracted/bootstrapper.exe -s --action install --components="intel.oneapi.win.mkl.devel" --eula=accept -p=NEED_VS2017_INTEGRATION=0 -p=NEED_VS2019_INTEGRATION=0 --log-dir=. + +ONEDNN_VERSION=3.10.2 +curl --netrc-optional -L -O https://github.com/uxlfoundation/oneDNN/archive/refs/tags/v${ONEDNN_VERSION}.tar.gz +tar xf *.tar.gz && rm *.tar.gz +cd oneDNN-* +cmake -DCMAKE_BUILD_TYPE=Release -DONEDNN_LIBRARY_TYPE=STATIC -DONEDNN_BUILD_EXAMPLES=OFF -DONEDNN_BUILD_TESTS=OFF -DONEDNN_ENABLE_WORKLOAD=INFERENCE -DONEDNN_ENABLE_PRIMITIVE="CONVOLUTION;REORDER" -DONEDNN_BUILD_GRAPH=OFF . 
+cmake --build . --config Release --target install --parallel 6 +cd .. +rm -r oneDNN-* + +cmake -GNinja -DCMAKE_BUILD_TYPE=Release -S . -B build -DCMAKE_CXX_FLAGS="-Wno-deprecated-literal-operator" -DCMAKE_HIP_FLAGS="-Wno-deprecated-literal-operator" -DCMAKE_INSTALL_PREFIX=$CTRANSLATE2_ROOT -DCMAKE_PREFIX_PATH="C:/Program Files (x86)/Intel/oneAPI/compiler/latest/lib;C:/Program Files (x86)/oneDNN" -DBUILD_CLI=OFF -DWITH_DNNL=ON -DWITH_HIP=ON -DCMAKE_HIP_ARCHITECTURES="$PYTORCH_ROCM_ARCH" +cmake --build build --config Release --target install --parallel 6 --verbose +rm -r build + +cp README.md python/ +cp $CTRANSLATE2_ROOT/bin/ctranslate2.dll python/ctranslate2/ +cp "C:/Program Files (x86)/Intel/oneAPI/2025.3/bin/libiomp5md.dll" python/ctranslate2/ From e7a472746f0cae632de4ef73749c13474cedd07a Mon Sep 17 00:00:00 2001 From: sssshhhhhh <193317444+sssshhhhhh@users.noreply.github.com> Date: Fri, 30 Jan 2026 10:06:33 +1100 Subject: [PATCH 3/6] Linux whls --- .github/workflows/ci.yml | 10 ++- .../prepare_build_environment_linux_rocm.sh | 79 +++++++++++++++++++ .../prepare_build_environment_windows_rocm.sh | 5 +- 3 files changed, 91 insertions(+), 3 deletions(-) create mode 100755 python/tools/prepare_build_environment_linux_rocm.sh diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8317296ec..a5877bc61 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -229,7 +229,11 @@ jobs: path: python/wheelhouse build-python-wheels-rocm: - runs-on: windows-2025 + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-24.04, windows-2025] + steps: - uses: actions/checkout@v6 with: @@ -241,10 +245,14 @@ jobs: package-dir: python output-dir: python/wheelhouse env: + CIBW_ENVIRONMENT_LINUX: ROCM_PATH=/opt/rocm LD_LIBRARY_PATH=/opt/rocm/lib/llvm/lib:$LD_LIBRARY_PATH CIBW_ENVIRONMENT_WINDOWS: CTRANSLATE2_ROOT='${{ github.workspace }}\install' + CIBW_BEFORE_ALL_LINUX: python/tools/prepare_build_environment_linux_rocm.sh CIBW_BEFORE_ALL_WINDOWS: 
bash python/tools/prepare_build_environment_windows_rocm.sh CIBW_BEFORE_BUILD: pip install -r python/install_requirements.txt + CIBW_MANYLINUX_X86_64_IMAGE: manylinux_2_28 CIBW_ARCHS: auto64 + CIBW_SKIP: "*-musllinux_*" - name: Upload Python wheels uses: actions/upload-artifact@v6 diff --git a/python/tools/prepare_build_environment_linux_rocm.sh b/python/tools/prepare_build_environment_linux_rocm.sh new file mode 100755 index 000000000..aac044f66 --- /dev/null +++ b/python/tools/prepare_build_environment_linux_rocm.sh @@ -0,0 +1,79 @@ +#! /bin/bash + +set -e +set -x + +rm -rf /host/usr/local/lib/{android,node_modules} +rm -rf /host/usr/local/.ghcup +rm -rf /host/usr/local/share/{powershell,chromium} +rm -rf /host/usr/local/julia* +rm -rf /host/usr/share/{dotnet,swift} +rm -rf /host/usr/share/az_* +rm -rf /host/usr/lib/{jvm,google-cloud-sdk} +rm -rf /host/opt/hostedtoolcache/{CodeQL,go,node,Ruby} +rm -rf /host/opt/{microsoft,az,google} +df -h + +export LIBRARY_PATH="/opt/rh/gcc-toolset-14/root/usr/lib/gcc/x86_64-redhat-linux/14:${LIBRARY_PATH:-}" + +tee /etc/yum.repos.d/rocm.repo < Date: Fri, 30 Jan 2026 19:52:59 +1100 Subject: [PATCH 4/6] Don't bundle rocm libs --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a5877bc61..69906caab 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -253,6 +253,7 @@ jobs: CIBW_MANYLINUX_X86_64_IMAGE: manylinux_2_28 CIBW_ARCHS: auto64 CIBW_SKIP: "*-musllinux_*" + CIBW_REPAIR_WHEEL_COMMAND_LINUX: 'auditwheel repair -w {dest_dir} --exclude "/opt/rocm/lib/*" {wheel}' - name: Upload Python wheels uses: actions/upload-artifact@v6 From e9f843bb2343391d2fc8b2bf3406429d10a0693f Mon Sep 17 00:00:00 2001 From: sssshhhhhh <193317444+sssshhhhhh@users.noreply.github.com> Date: Tue, 3 Feb 2026 00:06:03 +1100 Subject: [PATCH 5/6] Docker set LD_LIBRARY_PATH --- docker/Dockerfile_rocm | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) 
diff --git a/docker/Dockerfile_rocm b/docker/Dockerfile_rocm index 8bbf70d7e..c3f4f826a 100644 --- a/docker/Dockerfile_rocm +++ b/docker/Dockerfile_rocm @@ -68,7 +68,6 @@ ENV HIP_FLAGS=${HIP_FLAGS:-"-O3 -Wno-deprecated-literal-operator"} ARG HIP_ARCHITECTURES ENV HIP_ARCHITECTURES=${HIP_ARCHITECTURES:-"gfx1030;gfx1100;gfx1101;gfx1102;gfx1150;gfx1151;gfx1200;gfx1201"} ENV CTRANSLATE2_ROOT=/opt/ctranslate2 -ARG LD_LIBRARY_PATH ENV LD_LIBRARY_PATH=/usr/local/lib/:${LD_LIBRARY_PATH} RUN mkdir build_tmp && \ @@ -101,8 +100,8 @@ RUN apt-get update && \ rm -rf /var/lib/apt/lists/* ENV CTRANSLATE2_ROOT=/opt/ctranslate2 -ARG LD_LIBRARY_PATH -ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CTRANSLATE2_ROOT/lib +ENV ROCM_ROOT=/opt/rocm +ENV LD_LIBRARY_PATH=$CTRANSLATE2_ROOT/lib:$ROCM_ROOT/lib/llvm/lib:$LD_LIBRARY_PATH COPY --from=builder $CTRANSLATE2_ROOT $CTRANSLATE2_ROOT RUN python3 -m pip --no-cache-dir install $CTRANSLATE2_ROOT/*.whl && \ From 704f243e21f87b1b11cc0f2ac93ef40515469d18 Mon Sep 17 00:00:00 2001 From: sssshhhhhh <193317444+sssshhhhhh@users.noreply.github.com> Date: Tue, 3 Feb 2026 00:06:36 +1100 Subject: [PATCH 6/6] Include omp in linux whl --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 69906caab..43ab7d423 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -253,7 +253,7 @@ jobs: CIBW_MANYLINUX_X86_64_IMAGE: manylinux_2_28 CIBW_ARCHS: auto64 CIBW_SKIP: "*-musllinux_*" - CIBW_REPAIR_WHEEL_COMMAND_LINUX: 'auditwheel repair -w {dest_dir} --exclude "/opt/rocm/lib/*" {wheel}' + CIBW_REPAIR_WHEEL_COMMAND_LINUX: 'auditwheel repair -w {dest_dir} --exclude "/opt/rocm/lib/lib*" {wheel}' - name: Upload Python wheels uses: actions/upload-artifact@v6