Replace the thrust::reduce based primitives::sum in src/cuda/primitives.cu
with cub::DeviceReduce::Sum, issued on the stream returned by
cuda::get_cuda_stream().

Review changes applied to the added lines:
  * both cub::DeviceReduce::Sum calls are wrapped in CUDA_CHECK (they return
    a cudaError_t that was previously discarded);
  * each device allocation is owned by a small guard object, so it is freed
    on every exit path instead of only after the last CUDA_CHECK succeeded;
  * an empty input returns a value-initialized result (the same seed the
    removed thrust::reduce call used) without touching the device.

NOTE(review): this patch was recovered from a copy in which line breaks and
everything between angle brackets had been lost. The template arguments
(<typename T>, <Device::CUDA>, device_type<T>), the <cub/cub.cuh> include
target and the indentation are reconstructed; confirm them against the tree
before applying. Three context #include lines of the first hunk could not be
recovered and were dropped, which shortens that hunk. The index line is kept
from the original patch; its post-image id predates the review changes.

NOTE(review): cudaMalloc and cudaFree run on every call; if the project has
a device allocator, using it here is worth considering.

diff --git a/src/cuda/primitives.cu b/src/cuda/primitives.cu
index 70bcaeaeb..7830467bb 100644
--- a/src/cuda/primitives.cu
+++ b/src/cuda/primitives.cu
@@ -1,2 +1,4 @@
 #include "ctranslate2/primitives.h"
 
+#include <cub/cub.cuh>
+
@@ -82,11 +84,59 @@ namespace ctranslate2 {
   template<>
   template <typename T>
   T primitives<Device::CUDA>::sum(const T* array, dim_t size) {
-    return T(THRUST_CALL(thrust::reduce,
-                         cuda::device_cast(array),
-                         cuda::device_cast(array) + size,
-                         cuda::device_type<T>(),
-                         cuda::plus<cuda::device_type<T>>()));
+    using DeviceT = cuda::device_type<T>;
+
+    // An empty range sums to zero; skip the device round trip entirely.
+    if (size <= 0)
+      return T(DeviceT());
+
+    // Owns one cudaMalloc'ed pointer and frees it on every exit path, so
+    // nothing leaks if a CUDA_CHECK below leaves the function early.
+    struct DeviceBuffer {
+      void* ptr = nullptr;
+      ~DeviceBuffer() {
+        if (ptr != nullptr)
+          cudaFree(ptr);
+      }
+    };
+
+    const auto input = cuda::device_cast(array);
+    const cudaStream_t stream = cuda::get_cuda_stream();
+
+    DeviceBuffer result;
+    CUDA_CHECK(cudaMalloc(&result.ptr, sizeof(DeviceT)));
+    DeviceT* const d_result = static_cast<DeviceT*>(result.ptr);
+
+    // With a null workspace, CUB only writes the required size to
+    // temp_storage_bytes and does no work.
+    size_t temp_storage_bytes = 0;
+    CUDA_CHECK(cub::DeviceReduce::Sum(nullptr,
+                                      temp_storage_bytes,
+                                      input,
+                                      d_result,
+                                      size,
+                                      stream));
+
+    // Second call performs the reduction into d_result.
+    DeviceBuffer workspace;
+    CUDA_CHECK(cudaMalloc(&workspace.ptr, temp_storage_bytes));
+    CUDA_CHECK(cub::DeviceReduce::Sum(workspace.ptr,
+                                      temp_storage_bytes,
+                                      input,
+                                      d_result,
+                                      size,
+                                      stream));
+
+    // The copy is asynchronous: wait for the stream before reading h_result.
+    DeviceT h_result;
+    CUDA_CHECK(cudaMemcpyAsync(&h_result,
+                               d_result,
+                               sizeof(DeviceT),
+                               cudaMemcpyDeviceToHost,
+                               stream));
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    return T(h_result);
   }
 
   template<>