diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index 336d81215b..236f1093b4 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -20,12 +20,15 @@ #include #include #include +#include #include #include #include +#include #include #include +#include #include #include @@ -44,21 +47,22 @@ namespace cuvs::neighbors::vamana::detail { static const int blockD = 32; static const int maxBlocks = 10000; -// generate random permutation of inserts - TODO do this on GPU / faster +// generate random permutation of inserts on device template -void create_insert_permutation(std::vector& insert_order, uint32_t N) +void create_insert_permutation(raft::resources const& res, + raft::device_vector_view insert_order, + uint64_t seed = 0) { - insert_order.resize(N); - for (uint32_t i = 0; i < N; i++) { - insert_order[i] = (IdxT)i; - } - for (uint32_t i = 0; i < N; i++) { - uint32_t temp; - uint32_t rand_idx = rand() % N; - temp = insert_order[i]; - insert_order[i] = insert_order[rand_idx]; - insert_order[rand_idx] = temp; - } + uint32_t N = insert_order.extent(0); + auto exec = raft::resource::get_thrust_policy(res); + + thrust::sequence(exec, insert_order.data_handle(), insert_order.data_handle() + N, IdxT{0}); + + auto keys = raft::make_device_vector(res, N); + raft::random::RngState rng(seed); + raft::random::uniform(res, rng, keys.data_handle(), N, 0.0f, 1.0f); + + thrust::sort_by_key(exec, keys.data_handle(), keys.data_handle() + N, insert_order.data_handle()); } template @@ -179,8 +183,8 @@ void batched_insert_vamana( dim + align_padding)); // Create random permutation for order of node inserts into graph - std::vector insert_order; - create_insert_permutation(insert_order, (uint32_t)N); + auto insert_order = raft::make_device_vector(res, N); + create_insert_permutation(res, insert_order.view()); // Calculate the shared memory sizes of each kernel int sort_smem_size = 0; @@ -237,10 +241,10 @@ void batched_insert_vamana( int num_blocks = min(maxBlocks, step_size); // Copy ids to be inserted for this batch - raft::copy( - res, - raft::make_device_vector_view(query_ids.data_handle(), int64_t(step_size)), - raft::make_host_vector_view(insert_order.data() + start, int64_t(step_size))); + raft::copy(res, + raft::make_device_vector_view(query_ids.data_handle(), int64_t(step_size)), + raft::make_device_vector_view(insert_order.data_handle() + start, + int64_t(step_size))); set_query_ids<<>>( query_list_ptr.data_handle(), query_ids.data_handle(), step_size);