Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions cpp/src/neighbors/detail/vamana/vamana_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@
#include <raft/core/host_mdspan.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/linalg/gemm.cuh>
#include <raft/matrix/copy.cuh>
#include <raft/matrix/init.cuh>
#include <raft/random/rng.cuh>

#include <thrust/device_vector.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/unique.h>

#include <cuvs/distance/distance.hpp>
Expand All @@ -44,21 +47,22 @@ namespace cuvs::neighbors::vamana::detail {
static const int blockD = 32;
static const int maxBlocks = 10000;

// generate random permutation of inserts - TODO do this on GPU / faster
// generate random permutation of inserts on device
template <typename IdxT>
void create_insert_permutation(std::vector<IdxT>& insert_order, uint32_t N)
void create_insert_permutation(raft::resources const& res,
raft::device_vector_view<IdxT, uint32_t> insert_order,
uint64_t seed = 0)
{
insert_order.resize(N);
for (uint32_t i = 0; i < N; i++) {
insert_order[i] = (IdxT)i;
}
for (uint32_t i = 0; i < N; i++) {
uint32_t temp;
uint32_t rand_idx = rand() % N;
temp = insert_order[i];
insert_order[i] = insert_order[rand_idx];
insert_order[rand_idx] = temp;
}
uint32_t N = insert_order.extent(0);
auto exec = raft::resource::get_thrust_policy(res);

thrust::sequence(exec, insert_order.data_handle(), insert_order.data_handle() + N, IdxT{0});

auto keys = raft::make_device_vector<float>(res, N);
raft::random::RngState rng(seed);
raft::random::uniform(res, rng, keys.data_handle(), N, 0.0f, 1.0f);

thrust::sort_by_key(exec, keys.data_handle(), keys.data_handle() + N, insert_order.data_handle());
}

template <typename IdxT>
Expand Down Expand Up @@ -179,8 +183,8 @@ void batched_insert_vamana(
dim + align_padding));

// Create random permutation for order of node inserts into graph
std::vector<IdxT> insert_order;
create_insert_permutation<IdxT>(insert_order, (uint32_t)N);
auto insert_order = raft::make_device_vector<IdxT, uint32_t>(res, N);
create_insert_permutation(res, insert_order.view());

// Calculate the shared memory sizes of each kernel
int sort_smem_size = 0;
Expand Down Expand Up @@ -237,10 +241,10 @@ void batched_insert_vamana(
int num_blocks = min(maxBlocks, step_size);

// Copy ids to be inserted for this batch
raft::copy(
res,
raft::make_device_vector_view(query_ids.data_handle(), int64_t(step_size)),
raft::make_host_vector_view<const IdxT>(insert_order.data() + start, int64_t(step_size)));
raft::copy(res,
raft::make_device_vector_view(query_ids.data_handle(), int64_t(step_size)),
raft::make_device_vector_view<const IdxT>(insert_order.data_handle() + start,
int64_t(step_size)));
set_query_ids<IdxT, accT><<<num_blocks, blockD, 0, stream>>>(
query_list_ptr.data_handle(), query_ids.data_handle(), step_size);

Expand Down
Loading