Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions src/cuda/primitives.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "ctranslate2/primitives.h"

#include <cub/device/device_reduce.cuh>

#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <thrust/device_ptr.h>
Expand Down Expand Up @@ -82,11 +84,43 @@ namespace ctranslate2 {
template<>
template <typename T>
T primitives<Device::CUDA>::sum(const T* array, dim_t size) {
return T(THRUST_CALL(thrust::reduce,
cuda::device_cast(array),
cuda::device_cast(array) + size,
cuda::device_type<T>(),
cuda::plus<cuda::device_type<T>>()));
using DeviceT = cuda::device_type<T>;

void* temp_storage = nullptr;
size_t temp_storage_bytes = 0;

DeviceT* d_result;
CUDA_CHECK(cudaMalloc(&d_result, sizeof(DeviceT)));

cub::DeviceReduce::Sum(temp_storage,
temp_storage_bytes,
cuda::device_cast(array),
d_result,
size,
cuda::get_cuda_stream());

CUDA_CHECK(cudaMalloc(&temp_storage, temp_storage_bytes));

cub::DeviceReduce::Sum(temp_storage,
temp_storage_bytes,
cuda::device_cast(array),
d_result,
size,
cuda::get_cuda_stream());

DeviceT h_result;
CUDA_CHECK(cudaMemcpyAsync(&h_result,
d_result,
sizeof(DeviceT),
cudaMemcpyDeviceToHost,
cuda::get_cuda_stream()));

CUDA_CHECK(cudaStreamSynchronize(cuda::get_cuda_stream()));

CUDA_CHECK(cudaFree(d_result));
CUDA_CHECK(cudaFree(temp_storage));

return T(h_result);
}

template<>
Expand Down