From 166309a82f0506c36c07d296b9361dbdd7f3097f Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 3 Jun 2026 16:27:24 +0000 Subject: [PATCH 1/4] Switch to device permute Signed-off-by: Mickael Ide --- .../neighbors/detail/vamana/vamana_build.cuh | 42 ++++++++++--------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index 336d81215b..6ece4c7897 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -20,12 +20,15 @@ #include #include #include +#include #include #include #include +#include #include #include +#include #include #include @@ -44,21 +47,22 @@ namespace cuvs::neighbors::vamana::detail { static const int blockD = 32; static const int maxBlocks = 10000; -// generate random permutation of inserts - TODO do this on GPU / faster +// generate random permutation of inserts on device template -void create_insert_permutation(std::vector& insert_order, uint32_t N) +void create_insert_permutation( + raft::resources const& res, + raft::device_vector_view insert_order uint64_t seed = 0) { - insert_order.resize(N); - for (uint32_t i = 0; i < N; i++) { - insert_order[i] = (IdxT)i; - } - for (uint32_t i = 0; i < N; i++) { - uint32_t temp; - uint32_t rand_idx = rand() % N; - temp = insert_order[i]; - insert_order[i] = insert_order[rand_idx]; - insert_order[rand_idx] = temp; - } + uint32_t N = insert_order.extent(0); + auto exec = raft::resource::get_thrust_policy(res); + + thrust::sequence(exec, insert_order.data_handle(), insert_order.data_handle() + N, IdxT{0}); + + auto keys = raft::make_device_vector(res, N); + raft::random::RngState rng(seed != 0 ? seed : static_cast(std::rand())); + raft::random::uniform(res, rng, keys.data_handle(), N, 0.0, 1.0); + + thrust::sort_by_key(exec, keys.data_handle(), keys.data_handle() + N, insert_order.data_handle()); } template @@ -179,8 +183,8 @@ void batched_insert_vamana( dim + align_padding)); // Create random permutation for order of node inserts into graph - std::vector insert_order; - create_insert_permutation(insert_order, (uint32_t)N); + auto insert_order = raft::make_device_vector(res, N); + create_insert_permutation(res, insert_order.view(), static_cast(N)); // Calculate the shared memory sizes of each kernel int sort_smem_size = 0; @@ -237,10 +241,10 @@ void batched_insert_vamana( int num_blocks = min(maxBlocks, step_size); // Copy ids to be inserted for this batch - raft::copy( - res, - raft::make_device_vector_view(query_ids.data_handle(), int64_t(step_size)), - raft::make_host_vector_view(insert_order.data() + start, int64_t(step_size))); + raft::copy(res, + raft::make_device_vector_view(query_ids.data_handle(), int64_t(step_size)), + raft::make_device_vector_view(insert_order.data_handle() + start, + int64_t(step_size))); set_query_ids<<>>( query_list_ptr.data_handle(), query_ids.data_handle(), step_size); From 33573496bd06f3dbecfa5fb5a9f289b06094382e Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 3 Jun 2026 19:11:57 +0000 Subject: [PATCH 2/4] Fix coderabbit review Signed-off-by: Mickael Ide --- cpp/src/neighbors/detail/vamana/vamana_build.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index 6ece4c7897..c6d24039e0 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -49,9 +49,9 @@ static const int maxBlocks = 10000; // generate random permutation of inserts on device template -void create_insert_permutation( - raft::resources const& res, - raft::device_vector_view insert_order uint64_t seed = 0) +void create_insert_permutation(raft::resources const& res, + raft::device_vector_view insert_order, + uint64_t seed = 0) { uint32_t N = insert_order.extent(0); auto exec = raft::resource::get_thrust_policy(res); @@ -59,7 +59,7 @@ void create_insert_permutation( thrust::sequence(exec, insert_order.data_handle(), insert_order.data_handle() + N, IdxT{0}); auto keys = raft::make_device_vector(res, N); - raft::random::RngState rng(seed != 0 ? seed : static_cast(std::rand())); + raft::random::RngState rng(seed); raft::random::uniform(res, rng, keys.data_handle(), N, 0.0, 1.0); thrust::sort_by_key(exec, keys.data_handle(), keys.data_handle() + N, insert_order.data_handle()); From 457f4c406f73f32954a30fd3b711737d3fcbaa74 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 3 Jun 2026 20:17:58 +0000 Subject: [PATCH 3/4] Fix datatype Signed-off-by: Mickael Ide --- cpp/src/neighbors/detail/vamana/vamana_build.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index c6d24039e0..b12405f55d 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -60,7 +60,7 @@ void create_insert_permutation(raft::resources const& res, auto keys = raft::make_device_vector(res, N); raft::random::RngState rng(seed); - raft::random::uniform(res, rng, keys.data_handle(), N, 0.0, 1.0); + raft::random::uniform(res, rng, keys.data_handle(), N, 0.0f, 1.0f); thrust::sort_by_key(exec, keys.data_handle(), keys.data_handle() + N, insert_order.data_handle()); } From e8e1904db57bdcf1d5e1e7730fe2edd118f8e7eb Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 3 Jun 2026 21:30:49 +0000 Subject: [PATCH 4/4] Fix call Signed-off-by: Mickael Ide --- cpp/src/neighbors/detail/vamana/vamana_build.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index b12405f55d..236f1093b4 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -184,7 +184,7 @@ void batched_insert_vamana( // Create random permutation for order of node inserts into graph auto insert_order = raft::make_device_vector(res, N); - create_insert_permutation(res, insert_order.view(), static_cast(N)); + create_insert_permutation(res, insert_order.view()); // Calculate the shared memory sizes of each kernel int sort_smem_size = 0;