From 609b0f3db6643c0a62dd89d59202ec4a43b339c5 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 16 Feb 2026 18:52:40 +0000 Subject: [PATCH 01/22] prune kernel smem --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 41efa1686f..270e76838a 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -166,14 +166,20 @@ __global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, g uint64_t* const stats) { __shared__ uint32_t smem_num_detour[MAX_DEGREE]; + extern __shared__ unsigned char smem_buf[]; + IdxT* const smem_knn_iA_neighbors = reinterpret_cast(smem_buf); + uint64_t* const num_retain = stats; uint64_t* const num_full = stats + 1; const uint64_t iA = blockIdx.x + (batch_size * batch_id); if (iA >= graph_size) { return; } + + // Load this node's neighbor row into shared memory to reduce global reads for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { - smem_num_detour[k] = 0; - if (knn_graph[k + ((uint64_t)graph_degree * iA)] == iA) { + smem_num_detour[k] = 0; + smem_knn_iA_neighbors[k] = knn_graph[k + ((uint64_t)graph_degree * iA)]; + if (smem_knn_iA_neighbors[k] == iA) { // Lower the priority of self-edge smem_num_detour[k] = graph_degree; } @@ -182,14 +188,14 @@ __global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, g // count number of detours (A->D->B) for (uint32_t kAD = 0; kAD < graph_degree - 1; kAD++) { - const uint64_t iD = knn_graph[kAD + (graph_degree * iA)]; + const uint64_t iD = smem_knn_iA_neighbors[kAD]; if (iD >= graph_size) { continue; } for (uint32_t kDB = threadIdx.x; kDB < graph_degree; kDB += blockDim.x) { const uint64_t iB_candidate = knn_graph[kDB + ((uint64_t)graph_degree * iD)]; for (uint32_t kAB = kAD + 1; kAB < graph_degree; kAB++) { // if 
( kDB < kAB ) { - const uint64_t iB = knn_graph[kAB + (graph_degree * iA)]; + const uint64_t iB = smem_knn_iA_neighbors[kAB]; if (iB == iB_candidate) { atomicAdd(smem_num_detour + kAB, 1); break; @@ -1298,9 +1304,10 @@ void optimize( RAFT_CUDA_TRY(cudaMemsetAsync( dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, raft::resource::get_cuda_stream(res))); + const size_t prune_smem_size = knn_graph_degree * sizeof(IdxT); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { kern_prune - <<>>( + <<>>( d_input_graph.data_handle(), graph_size, knn_graph_degree, From a320e0e90453527e156da007bb96dc00de3898c0 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Wed, 18 Feb 2026 16:26:54 +0000 Subject: [PATCH 02/22] reduce copies within reverse graph compute --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 61 ++++++++++++------- 1 file changed, 39 insertions(+), 22 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 270e76838a..f3b0f0778e 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -244,6 +244,29 @@ __global__ void kern_make_rev_graph(const IdxT* const dest_nodes, // [graph_ } } +// Build reverse graph from column k of output_graph (avoids per-column host fill and H2D copy). 
+template +__global__ void kern_make_rev_graph_column(const IdxT* const output_graph, // [graph_size, degree] + IdxT* const rev_graph, + uint32_t* const rev_graph_count, + const uint32_t graph_size, + const uint32_t degree, + const uint32_t k) +{ + const uint64_t tid = threadIdx.x + (blockDim.x * blockIdx.x); + const uint64_t tnum = blockDim.x * gridDim.x; + + for (uint64_t src_id = tid; src_id < graph_size; src_id += tnum) { + const IdxT dest_id = output_graph[k + (static_cast(degree) * src_id)]; + if (dest_id >= graph_size) continue; + + const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); + if (pos < degree) { + rev_graph[(static_cast(degree) * dest_id) + pos] = static_cast(src_id); + } + } +} + template __device__ __host__ LabelT get_root_label(IdxT i, const LabelT* label) { @@ -1444,32 +1467,26 @@ void optimize( graph_size * sizeof(uint32_t), raft::resource::get_cuda_stream(res))); - auto dest_nodes = raft::make_host_vector(graph_size); - auto d_dest_nodes = - raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); - - for (uint64_t k = 0; k < output_graph_degree; k++) { -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - // dest_nodes.data_handle()[i] = output_graph_ptr[k + (output_graph_degree * i)]; - dest_nodes(i) = output_graph_ptr[k + (output_graph_degree * i)]; - } - raft::resource::sync_stream(res); - - raft::copy(d_dest_nodes.data_handle(), - dest_nodes.data_handle(), - graph_size, - raft::resource::get_cuda_stream(res)); + // Copy full output graph to device once; kernel indexes by column k (no per-column H2D copy). + // TODO: depending on available device memory, this may need to be split into multiple copies. 
+ auto d_output_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); + raft::copy(d_output_graph.data_handle(), + output_graph_ptr, + static_cast(graph_size) * output_graph_degree, + raft::resource::get_cuda_stream(res)); - dim3 threads(256, 1, 1); - dim3 blocks(1024, 1, 1); - kern_make_rev_graph<<>>( - d_dest_nodes.data_handle(), + dim3 threads(256, 1, 1); + dim3 blocks(1024, 1, 1); + for (uint32_t k = 0; k < output_graph_degree; k++) { + kern_make_rev_graph_column<<>>( + d_output_graph.data_handle(), d_rev_graph.data_handle(), d_rev_graph_count.data_handle(), graph_size, - output_graph_degree); - RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); + output_graph_degree, + k); + RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %u / %u \r", k, output_graph_degree); } raft::resource::sync_stream(res); From 6d1a6187f2cc138c4941608d4d9c746a11e0d774 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Thu, 19 Feb 2026 23:07:23 +0000 Subject: [PATCH 03/22] optimize() draft move more compute to GPU --- .../neighbors/detail/cagra/cagra_build.cuh | 2 + cpp/src/neighbors/detail/cagra/graph_core.cuh | 864 ++++++++++++------ 2 files changed, 590 insertions(+), 276 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index 97d7bb1bac..152b603286 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -822,6 +822,8 @@ inline std::pair optimize_workspace_size(size_t n_rows, size_t index_size, bool mst_optimize = false) { + // TODO: MODIFY!! 
+ // MST optimization memory (host only) size_t mst_host = n_rows * index_size; // mst_graph_num_edges if (mst_optimize) { diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index f3b0f0778e..f2cd79ecb6 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -161,8 +161,8 @@ __global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, g const uint32_t degree, const uint32_t batch_size, const uint32_t batch_id, - uint8_t* const detour_count, // [graph_chunk_size, graph_degree] - uint32_t* const num_no_detour_edges, // [graph_size] + uint8_t* const detour_count, // [batch_size, graph_degree] + uint32_t* const num_no_detour_edges, // [batch_size] uint64_t* const stats) { __shared__ uint32_t smem_num_detour[MAX_DEGREE]; @@ -172,7 +172,9 @@ __global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, g uint64_t* const num_retain = stats; uint64_t* const num_full = stats + 1; - const uint64_t iA = blockIdx.x + (batch_size * batch_id); + const uint64_t iA = blockIdx.x + (batch_size * batch_id); + const uint64_t iA_batch = iA % static_cast(batch_size); + if (iA >= graph_size) { return; } // Load this node's neighbor row into shared memory to reduce global reads @@ -208,7 +210,7 @@ __global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, g uint32_t num_edges_no_detour = 0; for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { - detour_count[k + (graph_degree * iA)] = min(smem_num_detour[k], (uint32_t)255); + detour_count[k + (graph_degree * iA_batch)] = min(smem_num_detour[k], (uint32_t)255); if (smem_num_detour[k] == 0) { num_edges_no_detour++; } } num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 1); @@ -219,7 +221,7 @@ __global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, g num_edges_no_detour = min(num_edges_no_detour, degree); if (threadIdx.x 
== 0) { - num_no_detour_edges[iA] = num_edges_no_detour; + num_no_detour_edges[iA_batch] = num_edges_no_detour; atomicAdd((unsigned long long int*)num_retain, (unsigned long long int)num_edges_no_detour); if (num_edges_no_detour >= degree) { atomicAdd((unsigned long long int*)num_full, 1); } } @@ -244,26 +246,179 @@ __global__ void kern_make_rev_graph(const IdxT* const dest_nodes, // [graph_ } } -// Build reverse graph from column k of output_graph (avoids per-column host fill and H2D copy). +// Select output_graph_degree neighbors with smallest detour count per node (writes to device). template -__global__ void kern_make_rev_graph_column(const IdxT* const output_graph, // [graph_size, degree] - IdxT* const rev_graph, - uint32_t* const rev_graph_count, - const uint32_t graph_size, - const uint32_t degree, - const uint32_t k) +__global__ void kern_select_smallest_detour_neighbors( + const IdxT* const knn_graph, + uint64_t graph_size, + uint64_t knn_graph_degree, + uint64_t output_graph_degree, + const uint8_t* const d_detour_count, // [batch_size, graph_degree] + IdxT* output_graph_ptr, // [batch_size, output_graph_degree] + const uint32_t batch_size, + const uint32_t batch_id) { - const uint64_t tid = threadIdx.x + (blockDim.x * blockIdx.x); - const uint64_t tnum = blockDim.x * gridDim.x; + // FIXME: this does not really work for num_warps > 1 + constexpr unsigned warp_mask = 0xffffffff; + const uint32_t num_warps = blockDim.x / raft::WarpSize; + extern __shared__ unsigned char smem_buf[]; + uint32_t* smem_indices = reinterpret_cast(smem_buf); + uint16_t* smem_detour_count = + reinterpret_cast(&smem_indices[knn_graph_degree * num_warps]); - for (uint64_t src_id = tid; src_id < graph_size; src_id += tnum) { - const IdxT dest_id = output_graph[k + (static_cast(degree) * src_id)]; - if (dest_id >= graph_size) continue; + const uint32_t wid = threadIdx.x / raft::WarpSize; + const uint32_t lane_id = threadIdx.x % raft::WarpSize; + const uint64_t nid = 
static_cast(blockIdx.x) * num_warps + + (static_cast(batch_size) * batch_id * num_warps) + wid; - const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); - if (pos < degree) { - rev_graph[(static_cast(degree) * dest_id) + pos] = static_cast(src_id); + const uint64_t nid_batch = nid % static_cast(batch_size); + + if (nid >= graph_size) return; + + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { + smem_detour_count[(knn_graph_degree * wid) + k] = + d_detour_count[nid_batch * knn_graph_degree + k]; + smem_indices[(knn_graph_degree * wid) + k] = k; + } + __syncwarp(warp_mask); + + for (uint32_t i = 0; i < output_graph_degree; i++) { + uint32_t local_min = 256; + uint32_t local_idx = 0xffffffff; + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { + uint32_t c = smem_detour_count[(knn_graph_degree * wid) + k]; + if (c < local_min) { + local_min = c; + local_idx = smem_indices[(knn_graph_degree * wid) + k]; + } + } + uint32_t local_min_with_tag = (local_min << 16) | local_idx; + for (int offset = raft::WarpSize / 2; offset > 0; offset /= 2) { + uint32_t other = __shfl_down_sync(warp_mask, local_min_with_tag, offset); + local_min_with_tag = (local_min_with_tag <= other) ? 
local_min_with_tag : other; + } + uint32_t warp_min_tag = __shfl_sync(warp_mask, local_min_with_tag, 0); + uint32_t warp_local_idx = warp_min_tag & 0xffff; + + if (local_idx == warp_local_idx) { + output_graph_ptr[nid_batch * output_graph_degree + i] = + knn_graph[knn_graph_degree * nid + warp_local_idx]; + smem_detour_count[knn_graph_degree * wid + warp_local_idx] = 255; + } + __syncwarp(warp_mask); + } +} + +// Helper functions for merging the graph +template +__device__ unsigned int warp_pos_in_array(T val, const T* array, uint64_t num) +{ + unsigned int ret = num; + const uint32_t lane_id = threadIdx.x % 32; + for (uint64_t i = lane_id; i < num; i += 32) { + if (val == array[i]) { + ret = i; + break; + } + } + ret = __reduce_min_sync(0xffffffff, ret); + return ret; +} + +template +__device__ void thread_shift_array(T* array, uint64_t num) +{ + for (uint64_t i = num; i > 0; i--) { + array[i] = array[i - 1]; + } +} + +template +__global__ void kern_merge_graph(IdxT* output_graph, + const IdxT* const rev_graph, + uint32_t* const rev_graph_count, // [graph_size] + const uint32_t graph_size, + const uint32_t output_graph_degree, + const IdxT* const mst_graph, + const uint32_t mst_graph_degree, + const uint32_t* const mst_graph_num_edges_ptr, + const uint32_t batch_size, + const uint32_t batch_id, + bool guarantee_connectivity, + bool* check_num_protected_edges) +{ + extern __shared__ unsigned char smem_buf[]; + IdxT* smem_sorted_output_graph = reinterpret_cast(smem_buf); + + const uint32_t wid = threadIdx.x / 32; + const uint32_t lane_id = threadIdx.x % 32; + const uint32_t num_warps = blockDim.x / 32; + const uint64_t nid = blockIdx.x * num_warps + (batch_size * batch_id * num_warps) + wid; + if (nid >= graph_size) { return; } + + if (lane_id == 0) check_num_protected_edges[0] = true; + + const auto mst_graph_num_edges = mst_graph_num_edges_ptr[nid]; + // If guarantee_connectivity == true, use a temporal list to merge the + // neighbor lists of the graphs. 
+ if (guarantee_connectivity) { + for (uint32_t i = lane_id; i < mst_graph_degree; i += 32) { + smem_sorted_output_graph[i] = mst_graph[nid * mst_graph_degree + i]; } + __syncwarp(); + for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; + (pruned_j < output_graph_degree) && (output_j < output_graph_degree); + pruned_j++) { + const auto v = output_graph[output_graph_degree * nid + pruned_j]; + unsigned int dup = 0; + for (uint32_t m = lane_id; m < output_j; m += 32) { + if (v == smem_sorted_output_graph[m]) { + dup = 1; + break; + } + } + + unsigned int warp_dup = __ballot_sync(0xffffffff, dup); + if (warp_dup == 0) { + if (lane_id == 0) smem_sorted_output_graph[output_j] = v; + output_j++; + } + __syncwarp(); + } + } + + else { + for (uint32_t i = lane_id; i < output_graph_degree; i += 32) { + smem_sorted_output_graph[i] = output_graph[output_graph_degree * nid + i]; + } + __syncwarp(); + } + + const auto num_protected_edges = max(mst_graph_num_edges, output_graph_degree / 2); + + if (num_protected_edges > output_graph_degree) { check_num_protected_edges[0] = false; } + if (num_protected_edges == output_graph_degree) { return; } + + auto kr = min(rev_graph_count[nid], output_graph_degree); + + while (kr) { + kr -= 1; + if (rev_graph[kr + (output_graph_degree * nid)] < graph_size) { + uint64_t pos = warp_pos_in_array( + rev_graph[kr + (output_graph_degree * nid)], smem_sorted_output_graph, output_graph_degree); + if (pos < num_protected_edges) { continue; } + uint64_t num_shift = pos - num_protected_edges; + if (pos >= output_graph_degree) { num_shift = output_graph_degree - num_protected_edges - 1; } + if (lane_id == 0) { + thread_shift_array(smem_sorted_output_graph + num_protected_edges, num_shift); + smem_sorted_output_graph[num_protected_edges] = rev_graph[kr + (output_graph_degree * nid)]; + } + __syncwarp(); + } + } + + for (uint32_t i = lane_id; i < output_graph_degree; i += 32) { + output_graph[(output_graph_degree * nid) + i] = 
smem_sorted_output_graph[i]; } } @@ -737,11 +892,11 @@ void mst_opt_update_graph(IdxT* mst_graph_ptr, // an approximate MST. // * If the input kNN graph is disconnected, random connection is added to the largest cluster. // -template +template void mst_optimization(raft::resources const& res, - raft::host_matrix_view input_graph, - raft::host_matrix_view output_graph, - raft::host_vector_view mst_graph_num_edges, + InputMatrixView input_graph, + OutputMatrixView output_graph, + VectorView mst_graph_num_edges, bool use_gpu = true) { if (use_gpu) { @@ -1185,6 +1340,7 @@ void count_2hop_detours(raft::host_matrix_view k } } +// TODO allow pinned input for both knn_graph and new_graph template , raft::memory_type::host>> @@ -1213,9 +1369,10 @@ void optimize( "cagra::graph::optimize(%zu, %zu, %u)", graph_size, knn_graph_degree, output_graph_degree); // MST optimization - auto mst_graph = raft::make_host_matrix(0, 0); - auto mst_graph_num_edges = raft::make_host_vector(graph_size); + auto mst_graph = raft::make_pinned_matrix(res, 0, 0); + auto mst_graph_num_edges = raft::make_pinned_vector(res, graph_size); auto mst_graph_num_edges_ptr = mst_graph_num_edges.data_handle(); + #pragma omp parallel for for (uint64_t i = 0; i < graph_size; i++) { mst_graph_num_edges_ptr[i] = 0; @@ -1223,10 +1380,10 @@ void optimize( if (guarantee_connectivity) { raft::common::nvtx::range block_scope( "cagra::graph::optimize/check_connectivity"); - mst_graph = - raft::make_host_matrix(graph_size, output_graph_degree); + mst_graph = raft::make_pinned_matrix( + res, graph_size, output_graph_degree); RAFT_LOG_INFO("MST optimization is used to guarantee graph connectivity."); - mst_optimization(res, knn_graph, mst_graph.view(), mst_graph_num_edges.view(), use_gpu); + mst_optimization(res, knn_graph, mst_graph.view(), mst_graph_num_edges.view(), use_gpu); for (uint64_t i = 0; i < graph_size; i++) { if (i < 8 || i >= graph_size - 8) { @@ -1235,6 +1392,37 @@ void optimize( } } + uint32_t batch_size = 
+ std::min(static_cast(graph_size), static_cast(256 * 1024)); + const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; + + // + // If the available device memory is insufficient, do not use the GPU to count + // the number of 2-hop detours, but use the CPU. + // + // TODO: we should decide on a global strategy for this in a single place + // it comes down to input memory type and available memory which data should be copied to GPU + bool _use_gpu_prune = use_gpu; + if (_use_gpu_prune) { + try { + auto d_detour_count = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size, knn_graph_degree)); + auto d_num_no_detour_edges = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size)); + auto d_output_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size, output_graph_degree)); + // TODO we also want to consider pinned memory in case we are short on memory + auto d_input_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); + } catch (std::bad_alloc& e) { + RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU"); + _use_gpu_prune = false; + } catch (raft::logic_error& e) { + RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU (logic error)"); + _use_gpu_prune = false; + } + } + { raft::common::nvtx::range block_scope( "cagra::graph::optimize/prune"); @@ -1253,63 +1441,10 @@ void optimize( // specified number of edges are picked up for each node, starting with the // edge with the lowest number of 2-hop detours. // - auto detour_count = raft::make_host_matrix(graph_size, knn_graph_degree); - - // - // If the available device memory is insufficient, do not use the GPU to count - // the number of 2-hop detours, but use the CPU. 
- // - bool _use_gpu = use_gpu; - if (_use_gpu) { - try { - auto d_detour_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); - auto d_num_no_detour_edges = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size)); - auto d_input_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); - } catch (std::bad_alloc& e) { - RAFT_LOG_DEBUG("Insufficient memory for 2-hop node counting on GPU"); - _use_gpu = false; - } catch (raft::logic_error& e) { - RAFT_LOG_DEBUG("Insufficient memory for 2-hop node counting on GPU (logic error)"); - _use_gpu = false; - } - } - if (_use_gpu) { - // Count 2-hop detours on GPU - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune/2-hop-counting-by-GPU"); - const double time_2hop_count_start = cur_time(); - - uint64_t num_keep __attribute__((unused)) = 0; - uint64_t num_full __attribute__((unused)) = 0; - auto d_detour_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); - - RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), - 0xff, - graph_size * knn_graph_degree * sizeof(uint8_t), - raft::resource::get_cuda_stream(res))); - - auto d_num_no_detour_edges = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size)); - RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(), - 0x00, - graph_size * sizeof(uint32_t), - raft::resource::get_cuda_stream(res))); - - auto dev_stats = raft::make_device_vector(res, 2); - auto host_stats = raft::make_host_vector(2); - + if (_use_gpu_prune) { + // Pruning on GPU RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); - // Copy knn_graph over to device if necessary - device_matrix_view_from_host d_input_graph( - res, - raft::make_host_matrix_view( - knn_graph.data_handle(), graph_size, knn_graph_degree)); - constexpr int MAX_DEGREE = 1024; if (knn_graph_degree > MAX_DEGREE) 
{ RAFT_FAIL( @@ -1318,17 +1453,47 @@ void optimize( knn_graph_degree, MAX_DEGREE); } - const uint32_t batch_size = - std::min(static_cast(graph_size), static_cast(256 * 1024)); - const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - const dim3 threads_prune(32, 1, 1); - const dim3 blocks_prune(batch_size, 1, 1); + const double prune_start = cur_time(); + + uint64_t num_keep __attribute__((unused)) = 0; + uint64_t num_full __attribute__((unused)) = 0; + auto dev_stats = raft::make_device_vector(res, 2); + auto host_stats = raft::make_host_vector(2); RAFT_CUDA_TRY(cudaMemsetAsync( dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, raft::resource::get_cuda_stream(res))); - const size_t prune_smem_size = knn_graph_degree * sizeof(IdxT); + // Copy knn_graph over to device if necessary + // TODO: should we use pinned memory if we have issues fitting on GPU? + device_matrix_view_from_host d_input_graph( + res, + raft::make_host_matrix_view( + knn_graph.data_handle(), graph_size, knn_graph_degree)); + + // data structures per batch + auto d_detour_count = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size, knn_graph_degree)); + auto d_num_no_detour_edges = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size)); + auto d_output_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size, output_graph_degree)); + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + // initialize the detour_count and num_no_detour_edges for the current batch + RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), + 0xff, + batch_size * knn_graph_degree * sizeof(uint8_t), + raft::resource::get_cuda_stream(res))); + + RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(), + 0x00, + batch_size * sizeof(uint32_t), + raft::resource::get_cuda_stream(res))); + + // count 2-hop detours for the current batch + const dim3 threads_prune(32, 1, 1); + const dim3 
blocks_prune(batch_size, 1, 1); + const size_t prune_smem_size = knn_graph_degree * sizeof(IdxT); kern_prune <<>>( d_input_graph.data_handle(), @@ -1340,6 +1505,30 @@ void optimize( d_detour_count.data_handle(), d_num_no_detour_edges.data_handle(), dev_stats.data_handle()); + + // select smallest-detour neighbors for the current batch + const size_t select_smem_size = + (knn_graph_degree * knn_graph_degree) * (sizeof(uint16_t) + sizeof(uint32_t)); + const dim3 threads_select(32, 1, 1); + const dim3 blocks_select(batch_size, 1, 1); + kern_select_smallest_detour_neighbors + <<>>(d_input_graph.data_handle(), + graph_size, + knn_graph_degree, + output_graph_degree, + d_detour_count.data_handle(), + d_output_graph.data_handle(), + batch_size, + i_batch); + + raft::copy(output_graph_ptr, + d_output_graph.data_handle() + i_batch * batch_size * output_graph_degree, + static_cast(batch_size) * output_graph_degree, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); RAFT_LOG_DEBUG( "# Pruning kNN Graph on GPUs (%.1lf %%)\r", @@ -1348,96 +1537,93 @@ void optimize( raft::resource::sync_stream(res); RAFT_LOG_DEBUG("\n"); - raft::copy(detour_count.data_handle(), - d_detour_count.data_handle(), - detour_count.size(), - raft::resource::get_cuda_stream(res)); - raft::copy( host_stats.data_handle(), dev_stats.data_handle(), 2, raft::resource::get_cuda_stream(res)); num_keep = host_stats.data_handle()[0]; num_full = host_stats.data_handle()[1]; - const double time_2hop_count_end = cur_time(); + const double prune_end = cur_time(); RAFT_LOG_DEBUG( - "# Time for 2-hop detour counting on GPU: %.1lf sec, " + "# Time for pruning on GPU: %.1lf sec, " "avg_no_detour_edges_per_node: %.2lf/%u, " "nodes_with_no_detour_at_all_edges: %.1lf%%", - time_2hop_count_end - time_2hop_count_start, + prune_end - prune_start, (double)num_keep / graph_size, output_graph_degree, (double)num_full / graph_size * 100); } else { - // Count 2-hop detours on CPU - 
raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune/2-hop-counting-by-CPU"); - const double time_2hop_count_start = cur_time(); + // Pruning on CPU + auto detour_count = raft::make_host_matrix(graph_size, knn_graph_degree); - count_2hop_detours(knn_graph, detour_count.view()); + { + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/prune/2-hop-counting-by-CPU"); + const double time_2hop_count_start = cur_time(); - const double time_2hop_count_end = cur_time(); - RAFT_LOG_DEBUG("# Time for 2-hop detour counting on CPU: %.1lf sec", - time_2hop_count_end - time_2hop_count_start); - } + count_2hop_detours(knn_graph, detour_count.view()); - // Create pruned kNN graph - bool invalid_neighbor_list = false; + const double time_2hop_count_end = cur_time(); + RAFT_LOG_DEBUG("# Time for 2-hop detour counting on CPU: %.1lf sec", + time_2hop_count_end - time_2hop_count_start); + } + bool invalid_neighbor_list = false; #pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - // Find the `output_graph_degree` smallest detourable count nodes by checking the detourable - // count of the neighbors while increasing the target detourable count from zero. - uint64_t pk = 0; - uint32_t num_detour = 0; - for (uint32_t l = 0; l < knn_graph_degree && pk < output_graph_degree; l++) { - uint32_t next_num_detour = std::numeric_limits::max(); - for (uint64_t k = 0; k < knn_graph_degree; k++) { - const auto num_detour_k = detour_count(i, k); - // Find the detourable count to check in the next iteration - if (num_detour_k > num_detour) { - next_num_detour = std::min(static_cast(num_detour_k), next_num_detour); - } - - // Store the neighbor index if its detourable count is equal to `num_detour`. 
- if (num_detour_k != num_detour) { continue; } + for (uint64_t i = 0; i < graph_size; i++) { + // Find the `output_graph_degree` smallest detourable count nodes by checking the detourable + // count of the neighbors while increasing the target detourable count from zero. + uint64_t pk = 0; + uint32_t num_detour = 0; + for (uint32_t l = 0; l < knn_graph_degree && pk < output_graph_degree; l++) { + uint32_t next_num_detour = std::numeric_limits::max(); + for (uint64_t k = 0; k < knn_graph_degree; k++) { + const auto num_detour_k = detour_count(i, k); + // Find the detourable count to check in the next iteration + if (num_detour_k > num_detour) { + next_num_detour = std::min(static_cast(num_detour_k), next_num_detour); + } - // Check duplication and append - const auto candidate_node = knn_graph(i, k); - bool dup = false; - for (uint32_t dk = 0; dk < pk; dk++) { - if (candidate_node == output_graph_ptr[i * output_graph_degree + dk]) { - dup = true; - break; + // Store the neighbor index if its detourable count is equal to `num_detour`. + if (num_detour_k != num_detour) { continue; } + + // Check duplication and append + const auto candidate_node = knn_graph(i, k); + bool dup = false; + for (uint32_t dk = 0; dk < pk; dk++) { + if (candidate_node == output_graph_ptr[i * output_graph_degree + dk]) { + dup = true; + break; + } } - } - if (!dup && candidate_node < graph_size) { - output_graph_ptr[i * output_graph_degree + pk] = candidate_node; - pk += 1; + if (!dup && candidate_node < graph_size) { + output_graph_ptr[i * output_graph_degree + pk] = candidate_node; + pk += 1; + } + if (pk >= output_graph_degree) break; } if (pk >= output_graph_degree) break; - } - if (pk >= output_graph_degree) break; - if (next_num_detour == std::numeric_limits::max()) { - // There are no valid edges enough in the initial kNN graph. Break the loop here and catch - // the error at the next validation (pk != output_graph_degree). 
- break; + if (next_num_detour == std::numeric_limits::max()) { + // There are no valid edges enough in the initial kNN graph. Break the loop here and + // catch the error at the next validation (pk != output_graph_degree). + break; + } + num_detour = next_num_detour; + } + if (pk != output_graph_degree) { + RAFT_LOG_DEBUG( + "Couldn't find the output_graph_degree (%lu) smallest detourable count nodes for " + "node %lu in the rank-based node reranking process", + output_graph_degree, + i); + invalid_neighbor_list = true; } - num_detour = next_num_detour; - } - if (pk != output_graph_degree) { - RAFT_LOG_DEBUG( - "Couldn't find the output_graph_degree (%lu) smallest detourable count nodes for " - "node %lu in the rank-based node reranking process", - output_graph_degree, - i); - invalid_neighbor_list = true; } + RAFT_EXPECTS( + !invalid_neighbor_list, + "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " + "many invalid or duplicated neighbor nodes. This error can occur, for example, if too many " + "overflows occur during the norm computation between the dataset vectors."); } - RAFT_EXPECTS( - !invalid_neighbor_list, - "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " - "many invalid or duplicated neighbor nodes. This error can occur, for example, if too many " - "overflows occur during the norm computation between the dataset vectors."); const double time_prune_end = cur_time(); RAFT_LOG_DEBUG("# Pruning time: %.1lf ms", (time_prune_end - time_prune_start) * 1000.0); @@ -1446,155 +1632,281 @@ void optimize( auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); auto rev_graph_count = raft::make_host_vector(graph_size); - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/reverse"); + bool _use_gpu_rev_graph = use_gpu; + // TODO: should we use pinned memory if we have issues fitting on GPU? 
+ if (_use_gpu_rev_graph) { + try { + auto d_rev_graph_count = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size)); + auto d_dest_nodes = + raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); + auto d_rev_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); + } catch (std::bad_alloc& e) { + RAFT_LOG_DEBUG("Insufficient memory for reverse graph on GPU"); + _use_gpu_rev_graph = false; + } catch (raft::logic_error& e) { + RAFT_LOG_DEBUG("Insufficient memory for reverse graph on GPU (logic error)"); + _use_gpu_rev_graph = false; + } + } + + const double time_make_start = cur_time(); + if (_use_gpu_rev_graph) { // - // Make reverse graph + // Make reverse graph on GPU // - const double time_make_start = cur_time(); + auto d_rev_graph_count = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size)); device_matrix_view_from_host d_rev_graph(res, rev_graph.view()); - RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph.data_handle(), - 0xff, - graph_size * output_graph_degree * sizeof(IdxT), - raft::resource::get_cuda_stream(res))); + device_matrix_view_from_host d_output_graph( + res, + raft::make_host_matrix_view( + output_graph_ptr, graph_size, output_graph_degree)); - auto d_rev_graph_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size)); - RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph_count.data_handle(), - 0x00, - graph_size * sizeof(uint32_t), - raft::resource::get_cuda_stream(res))); - - // Copy full output graph to device once; kernel indexes by column k (no per-column H2D copy). - // TODO: depending on available device memory, this may need to be split into multiple copies. 
- auto d_output_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); - raft::copy(d_output_graph.data_handle(), - output_graph_ptr, - static_cast(graph_size) * output_graph_degree, - raft::resource::get_cuda_stream(res)); + { + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/reverse"); + auto dest_nodes = raft::make_host_vector(graph_size); + auto d_dest_nodes = + raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); - dim3 threads(256, 1, 1); - dim3 blocks(1024, 1, 1); - for (uint32_t k = 0; k < output_graph_degree; k++) { - kern_make_rev_graph_column<<>>( - d_output_graph.data_handle(), - d_rev_graph.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - output_graph_degree, - k); - RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %u / %u \r", k, output_graph_degree); - } + RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph.data_handle(), + 0xff, + graph_size * output_graph_degree * sizeof(IdxT), + raft::resource::get_cuda_stream(res))); - raft::resource::sync_stream(res); - RAFT_LOG_DEBUG("\n"); + RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph_count.data_handle(), + 0x00, + graph_size * sizeof(uint32_t), + raft::resource::get_cuda_stream(res))); - if (d_rev_graph.allocated_memory()) { - raft::copy(rev_graph.data_handle(), - d_rev_graph.data_handle(), - graph_size * output_graph_degree, + for (uint64_t k = 0; k < output_graph_degree; k++) { +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + // dest_nodes.data_handle()[i] = output_graph_ptr[k + (output_graph_degree * i)]; + dest_nodes(i) = output_graph_ptr[k + (output_graph_degree * i)]; + } + raft::resource::sync_stream(res); + + raft::copy(d_dest_nodes.data_handle(), + dest_nodes.data_handle(), + graph_size, + raft::resource::get_cuda_stream(res)); + + dim3 threads(256, 1, 1); + dim3 blocks(1024, 1, 1); + kern_make_rev_graph<<>>( + d_dest_nodes.data_handle(), + d_rev_graph.data_handle(), + 
d_rev_graph_count.data_handle(), + graph_size, + output_graph_degree); + RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); + } + + raft::resource::sync_stream(res); + RAFT_LOG_DEBUG("\n"); + + if (d_rev_graph.allocated_memory()) { + raft::copy(rev_graph.data_handle(), + d_rev_graph.data_handle(), + graph_size * output_graph_degree, + raft::resource::get_cuda_stream(res)); + } + raft::copy(rev_graph_count.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, raft::resource::get_cuda_stream(res)); + + const double time_make_end = cur_time(); + RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", + (time_make_end - time_make_start) * 1000.0); } - raft::copy(rev_graph_count.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - raft::resource::get_cuda_stream(res)); - const double time_make_end = cur_time(); - RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", - (time_make_end - time_make_start) * 1000.0); - } + { + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/combine"); - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/combine"); - // - // Create search graphs from MST and pruned and reverse graphs - // - const double time_replace_start = cur_time(); + // Merging the prunned graph and the reverse graph + const double merge_graph_start = cur_time(); + + // Create a boolean variable on the GPU using RAFT device allocator + auto d_check_num_protected_edges = raft::make_device_scalar(res, true); + + const dim3 threads_merge(32, 1, 1); + const dim3 blocks_merge(batch_size, 1, 1); + const size_t merge_smem_size = (output_graph_degree + output_graph_degree) * sizeof(IdxT); + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + kern_merge_graph + <<>>( + d_output_graph.data_handle(), + d_rev_graph.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, + output_graph_degree, + mst_graph.data_handle(), + output_graph_degree, + mst_graph_num_edges_ptr, + 
batch_size, + i_batch, + guarantee_connectivity, + d_check_num_protected_edges.data_handle()); + } + + bool check_num_protected_edges = true; + raft::copy(&check_num_protected_edges, + d_check_num_protected_edges.data_handle(), + 1, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + + // TODO: is this required? + if (d_output_graph.allocated_memory()) { + raft::copy(output_graph_ptr, + d_output_graph.data_handle(), + graph_size * output_graph_degree, + raft::resource::get_cuda_stream(res)); + } + + const auto merge_graph_end = cur_time(); + RAFT_EXPECTS(check_num_protected_edges, + "Failed to merge the MST, pruned, and reverse edge graphs. " + "Some nodes have too " + "many MST optimization edges."); + + RAFT_LOG_DEBUG("# Time for merging graphs: %.1lf ms", + (merge_graph_end - merge_graph_start) * 1000.0); + } + } else { + { + // Make reverse graph on CPU + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/reverse"); + + auto rev_graph_ptr = rev_graph.data_handle(); + auto rev_graph_count_ptr = rev_graph_count.data_handle(); - bool check_num_protected_edges = true; #pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - auto my_rev_graph = rev_graph.data_handle() + (output_graph_degree * i); - auto my_out_graph = output_graph_ptr + (output_graph_degree * i); + for (uint64_t i = 0; i < graph_size; i++) { + rev_graph_count_ptr[i] = 0; + } - // If guarantee_connectivity == true, use a temporal list to merge the neighbor lists of the - // graphs. 
- std::vector temp_output_neighbor_list; - if (guarantee_connectivity) { - temp_output_neighbor_list.resize(output_graph_degree); - my_out_graph = temp_output_neighbor_list.data(); - const auto mst_graph_num_edges = mst_graph_num_edges_ptr[i]; - - // Set MST graph edges - for (uint32_t j = 0; j < mst_graph_num_edges; j++) { - my_out_graph[j] = mst_graph(i, j); + for (uint32_t k = 0; k < output_graph_degree; k++) { +#pragma omp parallel for + for (uint64_t src_id = 0; src_id < graph_size; src_id++) { + const IdxT dest_id = + output_graph_ptr[k + (static_cast(output_graph_degree) * src_id)]; + if (dest_id >= graph_size) continue; + uint32_t pos; +#pragma omp atomic capture + pos = rev_graph_count_ptr[dest_id]++; + if (pos < output_graph_degree) { + rev_graph_ptr[(static_cast(output_graph_degree) * dest_id) + pos] = + static_cast(src_id); + } } + } - // Set pruned graph edges - for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; - (pruned_j < output_graph_degree) && (output_j < output_graph_degree); - pruned_j++) { - const auto v = output_graph_ptr[output_graph_degree * i + pruned_j]; - - // duplication check - bool dup = false; - for (uint32_t m = 0; m < output_j; m++) { - if (v == my_out_graph[m]) { - dup = true; - break; - } + const double time_make_end = cur_time(); + RAFT_LOG_DEBUG("# Making reverse graph time (CPU): %.1lf ms", + (time_make_end - time_make_start) * 1000.0); + } + + { + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/combine"); + // + // Create search graphs from MST and pruned and reverse graphs + // + const double time_replace_start = cur_time(); + + bool check_num_protected_edges = true; +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + auto my_rev_graph = rev_graph.data_handle() + (output_graph_degree * i); + auto my_out_graph = output_graph_ptr + (output_graph_degree * i); + + // If guarantee_connectivity == true, use a temporal list to merge the neighbor lists of the + // graphs. 
+ std::vector temp_output_neighbor_list; + if (guarantee_connectivity) { + temp_output_neighbor_list.resize(output_graph_degree); + my_out_graph = temp_output_neighbor_list.data(); + const auto mst_graph_num_edges = mst_graph_num_edges_ptr[i]; + + // Set MST graph edges + for (uint32_t j = 0; j < mst_graph_num_edges; j++) { + my_out_graph[j] = mst_graph(i, j); } - if (!dup) { - my_out_graph[output_j] = v; - output_j++; + // Set pruned graph edges + for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; + (pruned_j < output_graph_degree) && (output_j < output_graph_degree); + pruned_j++) { + const auto v = output_graph_ptr[output_graph_degree * i + pruned_j]; + + // duplication check + bool dup = false; + for (uint32_t m = 0; m < output_j; m++) { + if (v == my_out_graph[m]) { + dup = true; + break; + } + } + + if (!dup) { + my_out_graph[output_j] = v; + output_j++; + } } } - } - const auto num_protected_edges = - std::max(mst_graph_num_edges_ptr[i], output_graph_degree / 2); - if (num_protected_edges > output_graph_degree) { check_num_protected_edges = false; } - if (num_protected_edges == output_graph_degree) continue; - - // Replace some edges of the output graph with edges of the reverse graph. - auto kr = std::min(rev_graph_count.data_handle()[i], output_graph_degree); - while (kr) { - kr -= 1; - if (my_rev_graph[kr] < graph_size) { - uint64_t pos = pos_in_array(my_rev_graph[kr], my_out_graph, output_graph_degree); - if (pos < num_protected_edges) { continue; } - uint64_t num_shift = pos - num_protected_edges; - if (pos >= output_graph_degree) { - num_shift = output_graph_degree - num_protected_edges - 1; + const auto num_protected_edges = + std::max(mst_graph_num_edges_ptr[i], output_graph_degree / 2); + if (num_protected_edges > output_graph_degree) { check_num_protected_edges = false; } + if (num_protected_edges == output_graph_degree) continue; + + // Replace some edges of the output graph with edges of the reverse graph. 
+ auto kr = std::min(rev_graph_count.data_handle()[i], output_graph_degree); + while (kr) { + kr -= 1; + if (my_rev_graph[kr] < graph_size) { + uint64_t pos = pos_in_array(my_rev_graph[kr], my_out_graph, output_graph_degree); + if (pos < num_protected_edges) { continue; } + uint64_t num_shift = pos - num_protected_edges; + if (pos >= output_graph_degree) { + num_shift = output_graph_degree - num_protected_edges - 1; + } + shift_array(my_out_graph + num_protected_edges, num_shift); + my_out_graph[num_protected_edges] = my_rev_graph[kr]; } - shift_array(my_out_graph + num_protected_edges, num_shift); - my_out_graph[num_protected_edges] = my_rev_graph[kr]; } - } - // If guarantee_connectivity == true, move the output neighbor list from the temporal list to - // the output list. If false, the copy is not needed because my_out_graph is a pointer to the - // output buffer. - if (guarantee_connectivity) { - for (uint32_t j = 0; j < output_graph_degree; j++) { - output_graph_ptr[(output_graph_degree * i) + j] = my_out_graph[j]; + // If guarantee_connectivity == true, move the output neighbor list from the temporal list + // to the output list. If false, the copy is not needed because my_out_graph is a pointer to + // the output buffer. + if (guarantee_connectivity) { + for (uint32_t j = 0; j < output_graph_degree; j++) { + output_graph_ptr[(output_graph_degree * i) + j] = my_out_graph[j]; + } } } - } - RAFT_EXPECTS(check_num_protected_edges, - "Failed to merge the MST, pruned, and reverse edge graphs. Some nodes have too " - "many MST optimization edges."); + RAFT_EXPECTS(check_num_protected_edges, + "Failed to merge the MST, pruned, and reverse edge graphs. 
Some nodes have too " + "many MST optimization edges."); - const double time_replace_end = cur_time(); - RAFT_LOG_DEBUG("# Replacing edges time: %.1lf ms", - (time_replace_end - time_replace_start) * 1000.0); + const double time_replace_end = cur_time(); + RAFT_LOG_DEBUG("# Replacing edges time: %.1lf ms", + (time_replace_end - time_replace_start) * 1000.0); + } + } + // Check stats + { + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/stats"); /* stats */ uint64_t num_replaced_edges = 0; #pragma omp parallel for reduction(+ : num_replaced_edges) From 822faea739f9b77f13642c8201090273e27d32bc Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 20 Feb 2026 12:14:35 +0000 Subject: [PATCH 04/22] some fixes, cleanup --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 137 ++++++++++-------- 1 file changed, 77 insertions(+), 60 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index f2cd79ecb6..1b7e46e535 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -173,7 +173,7 @@ __global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, g uint64_t* const num_full = stats + 1; const uint64_t iA = blockIdx.x + (batch_size * batch_id); - const uint64_t iA_batch = iA % static_cast(batch_size); + const uint64_t iA_batch = blockIdx.x; if (iA >= graph_size) { return; } @@ -246,66 +246,69 @@ __global__ void kern_make_rev_graph(const IdxT* const dest_nodes, // [graph_ } } -// Select output_graph_degree neighbors with smallest detour count per node (writes to device). 
-template +// Based on the detour count, select the smallest detour count and its index +// (Pruning Update Kernel) +template __global__ void kern_select_smallest_detour_neighbors( - const IdxT* const knn_graph, + const IdxT* const knn_graph, // [graph_chunk_size, graph_degree] uint64_t graph_size, uint64_t knn_graph_degree, uint64_t output_graph_degree, - const uint8_t* const d_detour_count, // [batch_size, graph_degree] - IdxT* output_graph_ptr, // [batch_size, output_graph_degree] - const uint32_t batch_size, - const uint32_t batch_id) + uint8_t* const d_detour_count, // [batch_size, graph_degree] + IdxT* output_graph_ptr, + const uint32_t batch_size, // [batch_size, output_graph_degree] + const uint32_t batch_id, + uint32_t* const d_invalid_neighbor_list) { - // FIXME: this does not really work for num_warps > 1 - constexpr unsigned warp_mask = 0xffffffff; - const uint32_t num_warps = blockDim.x / raft::WarpSize; - extern __shared__ unsigned char smem_buf[]; - uint32_t* smem_indices = reinterpret_cast(smem_buf); - uint16_t* smem_detour_count = - reinterpret_cast(&smem_indices[knn_graph_degree * num_warps]); + assert(blockDim.x == 32); - const uint32_t wid = threadIdx.x / raft::WarpSize; - const uint32_t lane_id = threadIdx.x % raft::WarpSize; - const uint64_t nid = static_cast(blockIdx.x) * num_warps + - (static_cast(batch_size) * batch_id * num_warps) + wid; + // Allocate shared memory for detour counts and their indices + extern __shared__ IdxT smem_indices[]; + uint16_t* smem_detour_count = (uint16_t*)&smem_indices[knn_graph_degree]; - const uint64_t nid_batch = nid % static_cast(batch_size); + const uint64_t nid = blockIdx.x + (batch_size * batch_id); + const uint64_t nid_batch = blockIdx.x; - if (nid >= graph_size) return; + if (nid >= graph_size) { return; } - for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { - smem_detour_count[(knn_graph_degree * wid) + k] = - d_detour_count[nid_batch * knn_graph_degree + k]; - 
smem_indices[(knn_graph_degree * wid) + k] = k; + // Each uint64_t loads detour_count for its assigned k + for (uint32_t k = threadIdx.x; k < knn_graph_degree; k += blockDim.x) { + smem_detour_count[k] = d_detour_count[nid_batch * knn_graph_degree + k]; + smem_indices[k] = knn_graph[knn_graph_degree * nid + k]; } - __syncwarp(warp_mask); + __syncwarp(); + + const unsigned warp_mask = 0xffffffff; for (uint32_t i = 0; i < output_graph_degree; i++) { - uint32_t local_min = 256; + uint32_t local_min = 255; uint32_t local_idx = 0xffffffff; - for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { - uint32_t c = smem_detour_count[(knn_graph_degree * wid) + k]; - if (c < local_min) { - local_min = c; - local_idx = smem_indices[(knn_graph_degree * wid) + k]; + for (uint32_t k = threadIdx.x; k < knn_graph_degree; k += blockDim.x) { + if (smem_detour_count[k] < local_min) { + local_min = smem_detour_count[k]; + local_idx = k; } } - uint32_t local_min_with_tag = (local_min << 16) | local_idx; - for (int offset = raft::WarpSize / 2; offset > 0; offset /= 2) { - uint32_t other = __shfl_down_sync(warp_mask, local_min_with_tag, offset); - local_min_with_tag = (local_min_with_tag <= other) ? 
local_min_with_tag : other; + + uint32_t local_min_with_tag = (local_min << 16) | ((uint32_t)local_idx); + uint32_t warp_min_with_tag = __reduce_min_sync(warp_mask, local_min_with_tag); + uint32_t warp_min_count = warp_min_with_tag >> 16; + uint32_t warp_local_idx = warp_min_with_tag & 0xffff; + + if (warp_min_count == 255) { + // No valid position left; set error flag and fill remaining slots with sentinel + if (threadIdx.x == 0) { atomicExch(d_invalid_neighbor_list, 1u); } + break; } - uint32_t warp_min_tag = __shfl_sync(warp_mask, local_min_with_tag, 0); - uint32_t warp_local_idx = warp_min_tag & 0xffff; - if (local_idx == warp_local_idx) { - output_graph_ptr[nid_batch * output_graph_degree + i] = - knn_graph[knn_graph_degree * nid + warp_local_idx]; - smem_detour_count[knn_graph_degree * wid + warp_local_idx] = 255; + IdxT selected_node = smem_indices[warp_local_idx]; + + for (uint32_t k = threadIdx.x; k < knn_graph_degree; k += blockDim.x) { + if (smem_indices[k] == selected_node) { smem_detour_count[k] = 255; } } __syncwarp(warp_mask); + + if (threadIdx.x == 0) { output_graph_ptr[nid_batch * output_graph_degree + i] = selected_node; } } } @@ -350,19 +353,18 @@ __global__ void kern_merge_graph(IdxT* output_graph, extern __shared__ unsigned char smem_buf[]; IdxT* smem_sorted_output_graph = reinterpret_cast(smem_buf); - const uint32_t wid = threadIdx.x / 32; - const uint32_t lane_id = threadIdx.x % 32; - const uint32_t num_warps = blockDim.x / 32; - const uint64_t nid = blockIdx.x * num_warps + (batch_size * batch_id * num_warps) + wid; + assert(blockDim.x == 32); + + const uint64_t nid = blockIdx.x + (batch_size * batch_id); if (nid >= graph_size) { return; } - if (lane_id == 0) check_num_protected_edges[0] = true; + if (threadIdx.x == 0) check_num_protected_edges[0] = true; const auto mst_graph_num_edges = mst_graph_num_edges_ptr[nid]; // If guarantee_connectivity == true, use a temporal list to merge the // neighbor lists of the graphs. 
if (guarantee_connectivity) { - for (uint32_t i = lane_id; i < mst_graph_degree; i += 32) { + for (uint32_t i = threadIdx.x; i < mst_graph_degree; i += 32) { smem_sorted_output_graph[i] = mst_graph[nid * mst_graph_degree + i]; } __syncwarp(); @@ -371,7 +373,7 @@ __global__ void kern_merge_graph(IdxT* output_graph, pruned_j++) { const auto v = output_graph[output_graph_degree * nid + pruned_j]; unsigned int dup = 0; - for (uint32_t m = lane_id; m < output_j; m += 32) { + for (uint32_t m = threadIdx.x; m < output_j; m += 32) { if (v == smem_sorted_output_graph[m]) { dup = 1; break; @@ -380,7 +382,7 @@ __global__ void kern_merge_graph(IdxT* output_graph, unsigned int warp_dup = __ballot_sync(0xffffffff, dup); if (warp_dup == 0) { - if (lane_id == 0) smem_sorted_output_graph[output_j] = v; + if (threadIdx.x == 0) smem_sorted_output_graph[output_j] = v; output_j++; } __syncwarp(); @@ -388,7 +390,7 @@ __global__ void kern_merge_graph(IdxT* output_graph, } else { - for (uint32_t i = lane_id; i < output_graph_degree; i += 32) { + for (uint32_t i = threadIdx.x; i < output_graph_degree; i += 32) { smem_sorted_output_graph[i] = output_graph[output_graph_degree * nid + i]; } __syncwarp(); @@ -409,7 +411,7 @@ __global__ void kern_merge_graph(IdxT* output_graph, if (pos < num_protected_edges) { continue; } uint64_t num_shift = pos - num_protected_edges; if (pos >= output_graph_degree) { num_shift = output_graph_degree - num_protected_edges - 1; } - if (lane_id == 0) { + if (threadIdx.x == 0) { thread_shift_array(smem_sorted_output_graph + num_protected_edges, num_shift); smem_sorted_output_graph[num_protected_edges] = rev_graph[kr + (output_graph_degree * nid)]; } @@ -417,7 +419,7 @@ __global__ void kern_merge_graph(IdxT* output_graph, } } - for (uint32_t i = lane_id; i < output_graph_degree; i += 32) { + for (uint32_t i = threadIdx.x; i < output_graph_degree; i += 32) { output_graph[(output_graph_degree * nid) + i] = smem_sorted_output_graph[i]; } } @@ -1477,6 +1479,7 @@ void 
optimize( res, large_tmp_mr, raft::make_extents(batch_size)); auto d_output_graph = raft::make_device_mdarray( res, large_tmp_mr, raft::make_extents(batch_size, output_graph_degree)); + auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { // initialize the detour_count and num_no_detour_edges for the current batch @@ -1507,8 +1510,7 @@ void optimize( dev_stats.data_handle()); // select smallest-detour neighbors for the current batch - const size_t select_smem_size = - (knn_graph_degree * knn_graph_degree) * (sizeof(uint16_t) + sizeof(uint32_t)); + const size_t select_smem_size = (knn_graph_degree) * (sizeof(uint16_t) + sizeof(IdxT)); const dim3 threads_select(32, 1, 1); const dim3 blocks_select(batch_size, 1, 1); kern_select_smallest_detour_neighbors @@ -1522,10 +1524,11 @@ void optimize( d_detour_count.data_handle(), d_output_graph.data_handle(), batch_size, - i_batch); + i_batch, + d_invalid_neighbor_list.data_handle()); - raft::copy(output_graph_ptr, - d_output_graph.data_handle() + i_batch * batch_size * output_graph_degree, + raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, + d_output_graph.data_handle(), static_cast(batch_size) * output_graph_degree, raft::resource::get_cuda_stream(res)); @@ -1537,6 +1540,18 @@ void optimize( raft::resource::sync_stream(res); RAFT_LOG_DEBUG("\n"); + uint32_t invalid_neighbor_list = 0; + raft::copy(&invalid_neighbor_list, + d_invalid_neighbor_list.data_handle(), + 1, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + RAFT_EXPECTS( + invalid_neighbor_list == 0, + "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " + "many invalid or duplicated neighbor nodes. 
This error can occur, for example, if too many " + "overflows occur during the norm computation between the dataset vectors."); + raft::copy( host_stats.data_handle(), dev_stats.data_handle(), 2, raft::resource::get_cuda_stream(res)); num_keep = host_stats.data_handle()[0]; @@ -1642,6 +1657,8 @@ void optimize( raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); auto d_rev_graph = raft::make_device_mdarray( res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); + auto d_output_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); } catch (std::bad_alloc& e) { RAFT_LOG_DEBUG("Insufficient memory for reverse graph on GPU"); _use_gpu_rev_graph = false; @@ -1760,9 +1777,7 @@ void optimize( d_check_num_protected_edges.data_handle(), 1, raft::resource::get_cuda_stream(res)); - raft::resource::sync_stream(res); - // TODO: is this required? if (d_output_graph.allocated_memory()) { raft::copy(output_graph_ptr, d_output_graph.data_handle(), @@ -1770,6 +1785,8 @@ void optimize( raft::resource::get_cuda_stream(res)); } + raft::resource::sync_stream(res); + const auto merge_graph_end = cur_time(); RAFT_EXPECTS(check_num_protected_edges, "Failed to merge the MST, pruned, and reverse edge graphs. 
" From 9b1f741ca39eb4624a08b51d54a2429fa6b08eff Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Wed, 25 Feb 2026 15:41:10 +0000 Subject: [PATCH 05/22] some fixes --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 1b7e46e535..8705f555b6 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -268,21 +268,23 @@ __global__ void kern_select_smallest_detour_neighbors( const uint64_t nid = blockIdx.x + (batch_size * batch_id); const uint64_t nid_batch = blockIdx.x; + const uint32_t maxval16 = 0x0000ffff; if (nid >= graph_size) { return; } - // Each uint64_t loads detour_count for its assigned k + // Load indices and detour counts for each neighbor; invalidate out-of-bounds entries for (uint32_t k = threadIdx.x; k < knn_graph_degree; k += blockDim.x) { - smem_detour_count[k] = d_detour_count[nid_batch * knn_graph_degree + k]; smem_indices[k] = knn_graph[knn_graph_degree * nid + k]; + smem_detour_count[k] = (smem_indices[k] >= graph_size) + ? 
maxval16 + : (uint16_t)d_detour_count[nid_batch * knn_graph_degree + k]; } __syncwarp(); const unsigned warp_mask = 0xffffffff; - for (uint32_t i = 0; i < output_graph_degree; i++) { - uint32_t local_min = 255; - uint32_t local_idx = 0xffffffff; + uint32_t local_min = maxval16; + uint32_t local_idx = maxval16; for (uint32_t k = threadIdx.x; k < knn_graph_degree; k += blockDim.x) { if (smem_detour_count[k] < local_min) { local_min = smem_detour_count[k]; @@ -295,8 +297,7 @@ __global__ void kern_select_smallest_detour_neighbors( uint32_t warp_min_count = warp_min_with_tag >> 16; uint32_t warp_local_idx = warp_min_with_tag & 0xffff; - if (warp_min_count == 255) { - // No valid position left; set error flag and fill remaining slots with sentinel + if (warp_min_count == maxval16 || warp_local_idx == maxval16) { if (threadIdx.x == 0) { atomicExch(d_invalid_neighbor_list, 1u); } break; } @@ -304,7 +305,7 @@ __global__ void kern_select_smallest_detour_neighbors( IdxT selected_node = smem_indices[warp_local_idx]; for (uint32_t k = threadIdx.x; k < knn_graph_degree; k += blockDim.x) { - if (smem_indices[k] == selected_node) { smem_detour_count[k] = 255; } + if (smem_indices[k] == selected_node) { smem_detour_count[k] = maxval16; } } __syncwarp(warp_mask); @@ -1355,7 +1356,11 @@ void optimize( { RAFT_LOG_DEBUG( "# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1)); + + // large temporary memory for large arrays, e.g. everything >= O(graph_size) auto large_tmp_mr = raft::resource::get_large_workspace_resource(res); + // temporary memory for small arrays, e.g. 
everything <= O(batchsize * graph_degree) + // auto tmp_mr = raft::resource::get_tmp_workspace_resource(res); RAFT_EXPECTS(knn_graph.extent(0) == new_graph.extent(0), "Each input array is expected to have the same number of rows"); @@ -1527,9 +1532,12 @@ void optimize( i_batch, d_invalid_neighbor_list.data_handle()); + size_t copy_size = + std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * + output_graph_degree; raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, d_output_graph.data_handle(), - static_cast(batch_size) * output_graph_degree, + copy_size, raft::resource::get_cuda_stream(res)); raft::resource::sync_stream(res); From ecf3b1db009a78adaeea268ff51d1c2d79763eb9 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 27 Feb 2026 15:39:01 +0000 Subject: [PATCH 06/22] extract prune into separate function --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 497 +++++++++--------- 1 file changed, 245 insertions(+), 252 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 8705f555b6..70cd29aa4a 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -1343,6 +1343,246 @@ void count_2hop_detours(raft::host_matrix_view k } } +// +// Prune unimportant edges based on 2-hop detour counts. +// +// The edge to be retained is determined without explicitly considering distance or angle. +// Suppose the edge is the k-th edge of some node-A to node-B (A->B). Among the edges +// originating at node-A, there are k-1 edges shorter than the edge A->B. Each of these +// k-1 edges are connected to a different k-1 nodes. Among these k-1 nodes, count the +// number of nodes with edges to node-B, which is the number of 2-hop detours for the +// edge A->B. 
Once the number of 2-hop detours has been counted for all edges, the +// specified number of edges are picked up for each node, starting with the edge with +// the lowest number of 2-hop detours. +// +template +void prune_graph(raft::resources const& res, + InputMatrixView knn_graph, + OutputMatrixView output_graph, + bool use_gpu) +{ + const uint64_t graph_size = output_graph.extent(0); + const uint64_t knn_graph_degree = knn_graph.extent(1); + const uint64_t output_graph_degree = output_graph.extent(1); + auto output_graph_ptr = output_graph.data_handle(); + + auto large_tmp_mr = raft::resource::get_large_workspace_resource(res); + + uint32_t batch_size = + std::min(static_cast(graph_size), static_cast(256 * 1024)); + const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; + + bool use_gpu_prune = use_gpu; + if (use_gpu_prune) { + try { + auto d_detour_count = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size, knn_graph_degree)); + auto d_num_no_detour_edges = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size)); + auto d_output_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size, output_graph_degree)); + auto d_input_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); + } catch (std::bad_alloc& e) { + RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU"); + use_gpu_prune = false; + } catch (raft::logic_error& e) { + RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU (logic error)"); + use_gpu_prune = false; + } + } + + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/prune"); + const double time_prune_start = cur_time(); + + if (use_gpu_prune) { + RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); + + constexpr int MAX_DEGREE = 1024; + if (knn_graph_degree > MAX_DEGREE) { + RAFT_FAIL( + "The degree of input knn graph is too large (%zu). 
" + "It must be equal to or smaller than %d.", + knn_graph_degree, + MAX_DEGREE); + } + + const double prune_start = cur_time(); + + uint64_t num_keep __attribute__((unused)) = 0; + uint64_t num_full __attribute__((unused)) = 0; + auto dev_stats = raft::make_device_vector(res, 2); + auto host_stats = raft::make_host_vector(2); + RAFT_CUDA_TRY(cudaMemsetAsync( + dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, raft::resource::get_cuda_stream(res))); + + device_matrix_view_from_host d_input_graph( + res, + raft::make_host_matrix_view( + knn_graph.data_handle(), graph_size, knn_graph_degree)); + + auto d_detour_count = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size, knn_graph_degree)); + auto d_num_no_detour_edges = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size)); + auto d_output_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(batch_size, output_graph_degree)); + auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); + + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), + 0xff, + batch_size * knn_graph_degree * sizeof(uint8_t), + raft::resource::get_cuda_stream(res))); + + RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(), + 0x00, + batch_size * sizeof(uint32_t), + raft::resource::get_cuda_stream(res))); + + const dim3 threads_prune(32, 1, 1); + const dim3 blocks_prune(batch_size, 1, 1); + const size_t prune_smem_size = knn_graph_degree * sizeof(IdxT); + kern_prune + <<>>( + d_input_graph.data_handle(), + graph_size, + knn_graph_degree, + output_graph_degree, + batch_size, + i_batch, + d_detour_count.data_handle(), + d_num_no_detour_edges.data_handle(), + dev_stats.data_handle()); + + const size_t select_smem_size = (knn_graph_degree) * (sizeof(uint16_t) + sizeof(IdxT)); + const dim3 threads_select(32, 1, 1); + const dim3 blocks_select(batch_size, 1, 1); + 
kern_select_smallest_detour_neighbors + <<>>( + d_input_graph.data_handle(), + graph_size, + knn_graph_degree, + output_graph_degree, + d_detour_count.data_handle(), + d_output_graph.data_handle(), + batch_size, + i_batch, + d_invalid_neighbor_list.data_handle()); + + size_t copy_size = + std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * + output_graph_degree; + raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, + d_output_graph.data_handle(), + copy_size, + raft::resource::get_cuda_stream(res)); + + raft::resource::sync_stream(res); + RAFT_LOG_DEBUG( + "# Pruning kNN Graph on GPUs (%.1lf %%)\r", + (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); + } + raft::resource::sync_stream(res); + RAFT_LOG_DEBUG("\n"); + + uint32_t invalid_neighbor_list = 0; + raft::copy(&invalid_neighbor_list, + d_invalid_neighbor_list.data_handle(), + 1, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + RAFT_EXPECTS( + invalid_neighbor_list == 0, + "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " + "many invalid or duplicated neighbor nodes. 
This error can occur, for example, if too many " + "overflows occur during the norm computation between the dataset vectors."); + + raft::copy( + host_stats.data_handle(), dev_stats.data_handle(), 2, raft::resource::get_cuda_stream(res)); + num_keep = host_stats.data_handle()[0]; + num_full = host_stats.data_handle()[1]; + + const double prune_end = cur_time(); + RAFT_LOG_DEBUG( + "# Time for pruning on GPU: %.1lf sec, " + "avg_no_detour_edges_per_node: %.2lf/%u, " + "nodes_with_no_detour_at_all_edges: %.1lf%%", + prune_end - prune_start, + (double)num_keep / graph_size, + output_graph_degree, + (double)num_full / graph_size * 100); + } else { + auto detour_count = raft::make_host_matrix(graph_size, knn_graph_degree); + + { + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/prune/2-hop-counting-by-CPU"); + const double time_2hop_count_start = cur_time(); + + auto knn_graph_view = raft::make_host_matrix_view( + knn_graph.data_handle(), knn_graph.extent(0), knn_graph.extent(1)); + count_2hop_detours(knn_graph_view, detour_count.view()); + + const double time_2hop_count_end = cur_time(); + RAFT_LOG_DEBUG("# Time for 2-hop detour counting on CPU: %.1lf sec", + time_2hop_count_end - time_2hop_count_start); + } + bool invalid_neighbor_list = false; +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + uint64_t pk = 0; + uint32_t num_detour = 0; + for (uint32_t l = 0; l < knn_graph_degree && pk < output_graph_degree; l++) { + uint32_t next_num_detour = std::numeric_limits::max(); + for (uint64_t k = 0; k < knn_graph_degree; k++) { + const auto num_detour_k = detour_count(i, k); + if (num_detour_k > num_detour) { + next_num_detour = std::min(static_cast(num_detour_k), next_num_detour); + } + + if (num_detour_k != num_detour) { continue; } + + const auto candidate_node = knn_graph(i, k); + bool dup = false; + for (uint32_t dk = 0; dk < pk; dk++) { + if (candidate_node == output_graph_ptr[i * output_graph_degree + dk]) { + dup = true; + 
break; + } + } + if (!dup && candidate_node < graph_size) { + output_graph_ptr[i * output_graph_degree + pk] = candidate_node; + pk += 1; + } + if (pk >= output_graph_degree) break; + } + if (pk >= output_graph_degree) break; + + if (next_num_detour == std::numeric_limits::max()) { break; } + num_detour = next_num_detour; + } + if (pk != output_graph_degree) { + RAFT_LOG_DEBUG( + "Couldn't find the output_graph_degree (%lu) smallest detourable count nodes for " + "node %lu in the rank-based node reranking process", + output_graph_degree, + i); + invalid_neighbor_list = true; + } + } + RAFT_EXPECTS( + !invalid_neighbor_list, + "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " + "many invalid or duplicated neighbor nodes. This error can occur, for example, if too many " + "overflows occur during the norm computation between the dataset vectors."); + } + + const double time_prune_end = cur_time(); + RAFT_LOG_DEBUG("# Pruning time: %.1lf ms", (time_prune_end - time_prune_start) * 1000.0); +} + // TODO allow pinned input for both knn_graph and new_graph template (graph_size), static_cast(256 * 1024)); - const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - - // - // If the available device memory is insufficient, do not use the GPU to count - // the number of 2-hop detours, but use the CPU. 
- // - // TODO: we should decide on a global strategy for this in a single place - // it comes down to input memory type and available memory which data should be copied to GPU - bool _use_gpu_prune = use_gpu; - if (_use_gpu_prune) { - try { - auto d_detour_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size, knn_graph_degree)); - auto d_num_no_detour_edges = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size)); - auto d_output_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size, output_graph_degree)); - // TODO we also want to consider pinned memory in case we are short on memory - auto d_input_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); - } catch (std::bad_alloc& e) { - RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU"); - _use_gpu_prune = false; - } catch (raft::logic_error& e) { - RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU (logic error)"); - _use_gpu_prune = false; - } - } - - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune"); - const double time_prune_start = cur_time(); - - // - // Prune unimportant edges. - // - // The edge to be retained is determined without explicitly considering - // distance or angle. Suppose the edge is the k-th edge of some node-A to - // node-B (A->B). Among the edges originating at node-A, there are k-1 edges - // shorter than the edge A->B. Each of these k-1 edges are connected to a - // different k-1 nodes. Among these k-1 nodes, count the number of nodes with - // edges to node-B, which is the number of 2-hop detours for the edge A->B. - // Once the number of 2-hop detours has been counted for all edges, the - // specified number of edges are picked up for each node, starting with the - // edge with the lowest number of 2-hop detours. 
- // - if (_use_gpu_prune) { - // Pruning on GPU - RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); - - constexpr int MAX_DEGREE = 1024; - if (knn_graph_degree > MAX_DEGREE) { - RAFT_FAIL( - "The degree of input knn graph is too large (%zu). " - "It must be equal to or smaller than %d.", - knn_graph_degree, - MAX_DEGREE); - } - - const double prune_start = cur_time(); - - uint64_t num_keep __attribute__((unused)) = 0; - uint64_t num_full __attribute__((unused)) = 0; - auto dev_stats = raft::make_device_vector(res, 2); - auto host_stats = raft::make_host_vector(2); - RAFT_CUDA_TRY(cudaMemsetAsync( - dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, raft::resource::get_cuda_stream(res))); - - // Copy knn_graph over to device if necessary - // TODO: should we use pinned memory if we have issues fitting on GPU? - device_matrix_view_from_host d_input_graph( - res, - raft::make_host_matrix_view( - knn_graph.data_handle(), graph_size, knn_graph_degree)); - - // data structures per batch - auto d_detour_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size, knn_graph_degree)); - auto d_num_no_detour_edges = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size)); - auto d_output_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size, output_graph_degree)); - auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); - - for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - // initialize the detour_count and num_no_detour_edges for the current batch - RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), - 0xff, - batch_size * knn_graph_degree * sizeof(uint8_t), - raft::resource::get_cuda_stream(res))); - - RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(), - 0x00, - batch_size * sizeof(uint32_t), - raft::resource::get_cuda_stream(res))); - - // count 2-hop detours for the current batch - const dim3 threads_prune(32, 1, 1); - const dim3 
blocks_prune(batch_size, 1, 1); - const size_t prune_smem_size = knn_graph_degree * sizeof(IdxT); - kern_prune - <<>>( - d_input_graph.data_handle(), - graph_size, - knn_graph_degree, - output_graph_degree, - batch_size, - i_batch, - d_detour_count.data_handle(), - d_num_no_detour_edges.data_handle(), - dev_stats.data_handle()); - - // select smallest-detour neighbors for the current batch - const size_t select_smem_size = (knn_graph_degree) * (sizeof(uint16_t) + sizeof(IdxT)); - const dim3 threads_select(32, 1, 1); - const dim3 blocks_select(batch_size, 1, 1); - kern_select_smallest_detour_neighbors - <<>>(d_input_graph.data_handle(), - graph_size, - knn_graph_degree, - output_graph_degree, - d_detour_count.data_handle(), - d_output_graph.data_handle(), - batch_size, - i_batch, - d_invalid_neighbor_list.data_handle()); - - size_t copy_size = - std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * - output_graph_degree; - raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, - d_output_graph.data_handle(), - copy_size, - raft::resource::get_cuda_stream(res)); - - raft::resource::sync_stream(res); - RAFT_LOG_DEBUG( - "# Pruning kNN Graph on GPUs (%.1lf %%)\r", - (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); - } - raft::resource::sync_stream(res); - RAFT_LOG_DEBUG("\n"); - - uint32_t invalid_neighbor_list = 0; - raft::copy(&invalid_neighbor_list, - d_invalid_neighbor_list.data_handle(), - 1, - raft::resource::get_cuda_stream(res)); - raft::resource::sync_stream(res); - RAFT_EXPECTS( - invalid_neighbor_list == 0, - "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " - "many invalid or duplicated neighbor nodes. 
This error can occur, for example, if too many " - "overflows occur during the norm computation between the dataset vectors."); - - raft::copy( - host_stats.data_handle(), dev_stats.data_handle(), 2, raft::resource::get_cuda_stream(res)); - num_keep = host_stats.data_handle()[0]; - num_full = host_stats.data_handle()[1]; - - const double prune_end = cur_time(); - RAFT_LOG_DEBUG( - "# Time for pruning on GPU: %.1lf sec, " - "avg_no_detour_edges_per_node: %.2lf/%u, " - "nodes_with_no_detour_at_all_edges: %.1lf%%", - prune_end - prune_start, - (double)num_keep / graph_size, - output_graph_degree, - (double)num_full / graph_size * 100); - } else { - // Pruning on CPU - auto detour_count = raft::make_host_matrix(graph_size, knn_graph_degree); - - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune/2-hop-counting-by-CPU"); - const double time_2hop_count_start = cur_time(); - - count_2hop_detours(knn_graph, detour_count.view()); - - const double time_2hop_count_end = cur_time(); - RAFT_LOG_DEBUG("# Time for 2-hop detour counting on CPU: %.1lf sec", - time_2hop_count_end - time_2hop_count_start); - } - bool invalid_neighbor_list = false; -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - // Find the `output_graph_degree` smallest detourable count nodes by checking the detourable - // count of the neighbors while increasing the target detourable count from zero. - uint64_t pk = 0; - uint32_t num_detour = 0; - for (uint32_t l = 0; l < knn_graph_degree && pk < output_graph_degree; l++) { - uint32_t next_num_detour = std::numeric_limits::max(); - for (uint64_t k = 0; k < knn_graph_degree; k++) { - const auto num_detour_k = detour_count(i, k); - // Find the detourable count to check in the next iteration - if (num_detour_k > num_detour) { - next_num_detour = std::min(static_cast(num_detour_k), next_num_detour); - } - - // Store the neighbor index if its detourable count is equal to `num_detour`. 
- if (num_detour_k != num_detour) { continue; } - - // Check duplication and append - const auto candidate_node = knn_graph(i, k); - bool dup = false; - for (uint32_t dk = 0; dk < pk; dk++) { - if (candidate_node == output_graph_ptr[i * output_graph_degree + dk]) { - dup = true; - break; - } - } - if (!dup && candidate_node < graph_size) { - output_graph_ptr[i * output_graph_degree + pk] = candidate_node; - pk += 1; - } - if (pk >= output_graph_degree) break; - } - if (pk >= output_graph_degree) break; - - if (next_num_detour == std::numeric_limits::max()) { - // There are no valid edges enough in the initial kNN graph. Break the loop here and - // catch the error at the next validation (pk != output_graph_degree). - break; - } - num_detour = next_num_detour; - } - if (pk != output_graph_degree) { - RAFT_LOG_DEBUG( - "Couldn't find the output_graph_degree (%lu) smallest detourable count nodes for " - "node %lu in the rank-based node reranking process", - output_graph_degree, - i); - invalid_neighbor_list = true; - } - } - RAFT_EXPECTS( - !invalid_neighbor_list, - "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " - "many invalid or duplicated neighbor nodes. 
This error can occur, for example, if too many " - "overflows occur during the norm computation between the dataset vectors."); - } - - const double time_prune_end = cur_time(); - RAFT_LOG_DEBUG("# Pruning time: %.1lf ms", (time_prune_end - time_prune_start) * 1000.0); - } + prune_graph(res, knn_graph, new_graph, use_gpu); auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); auto rev_graph_count = raft::make_host_vector(graph_size); @@ -1760,6 +1749,10 @@ void optimize( // Create a boolean variable on the GPU using RAFT device allocator auto d_check_num_protected_edges = raft::make_device_scalar(res, true); + uint32_t batch_size = + std::min(static_cast(graph_size), static_cast(256 * 1024)); + const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; + const dim3 threads_merge(32, 1, 1); const dim3 blocks_merge(batch_size, 1, 1); const size_t merge_smem_size = (output_graph_degree + output_graph_degree) * sizeof(IdxT); From 972d278c77c05add60e4518150e00b0f0f7898cf Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 2 Mar 2026 14:41:22 +0000 Subject: [PATCH 07/22] extract optimize components --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 1174 +++++++++-------- 1 file changed, 616 insertions(+), 558 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 70cd29aa4a..713b03ca20 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -674,6 +674,316 @@ void shift_array(T* array, uint64_t num) array[i] = array[i - 1]; } } + +template +void log_replaced_edges_stats(const IdxT* output_graph_ptr, + uint64_t graph_size, + uint64_t output_graph_degree) +{ + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/stats"); + uint64_t num_replaced_edges = 0; +#pragma omp parallel for reduction(+ : num_replaced_edges) + for (uint64_t i = 0; i < graph_size; i++) { + for (uint64_t k = 0; k < 
output_graph_degree; k++) { + const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; + const uint64_t pos = + pos_in_array(j, output_graph_ptr + (output_graph_degree * i), output_graph_degree); + if (pos == output_graph_degree) { num_replaced_edges += 1; } + } + } + RAFT_LOG_DEBUG("# Average number of replaced edges per node: %.2f", + (double)num_replaced_edges / graph_size); +} + +template +void log_incoming_edges_histogram(const IdxT* output_graph_ptr, + uint64_t graph_size, + uint64_t output_graph_degree) +{ + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/check_edges"); + auto in_edge_count = raft::make_host_vector(graph_size); + auto in_edge_count_ptr = in_edge_count.data_handle(); +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + in_edge_count_ptr[i] = 0; + } +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + for (uint64_t k = 0; k < output_graph_degree; k++) { + const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; + if (j >= graph_size) continue; +#pragma omp atomic + in_edge_count_ptr[j] += 1; + } + } + auto hist = raft::make_host_vector(output_graph_degree); + auto hist_ptr = hist.data_handle(); + for (uint64_t k = 0; k < output_graph_degree; k++) { + hist_ptr[k] = 0; + } +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + uint32_t count = in_edge_count_ptr[i]; + if (count >= output_graph_degree) continue; +#pragma omp atomic + hist_ptr[count] += 1; + } + RAFT_LOG_DEBUG("# Histogram for number of incoming edges\n"); + uint32_t sum_hist = 0; + for (uint64_t k = 0; k < output_graph_degree; k++) { + sum_hist += hist_ptr[k]; + RAFT_LOG_DEBUG("# %3lu, %8u, %lf, (%8u, %lf)\n", + k, + hist_ptr[k], + (double)hist_ptr[k] / graph_size, + sum_hist, + (double)sum_hist / graph_size); + } +} + +template +void check_duplicates_and_out_of_range(const IdxT* output_graph_ptr, + uint64_t graph_size, + uint64_t output_graph_degree) +{ + raft::common::nvtx::range 
block_scope( + "cagra::graph::optimize/check_duplicates"); + uint64_t num_dup = 0; + uint64_t num_oor = 0; +#pragma omp parallel for reduction(+ : num_dup) reduction(+ : num_oor) + for (uint64_t i = 0; i < graph_size; i++) { + auto my_out_graph = output_graph_ptr + (output_graph_degree * i); + for (uint32_t j = 0; j < output_graph_degree; j++) { + const auto neighbor_a = my_out_graph[j]; + + if (neighbor_a > graph_size) { + num_oor++; + continue; + } + + for (uint32_t k = j + 1; k < output_graph_degree; k++) { + const auto neighbor_b = my_out_graph[k]; + if (neighbor_a == neighbor_b) { num_dup++; } + } + } + } + RAFT_EXPECTS( + num_dup == 0, "%lu duplicated node(s) are found in the generated CAGRA graph", num_dup); + RAFT_EXPECTS( + num_oor == 0, "%lu out-of-range index node(s) are found in the generated CAGRA graph", num_oor); +} + +template +void merge_graph_gpu(raft::resources const& res, + IdxT* output_graph_ptr, + const IdxT* d_rev_graph, + uint32_t* d_rev_graph_count, + const IdxT* mst_graph_ptr, + const uint32_t* mst_graph_num_edges_ptr, + uint64_t graph_size, + uint64_t output_graph_degree, + bool guarantee_connectivity) +{ + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/combine"); + + const double merge_graph_start = cur_time(); + + device_matrix_view_from_host d_output_graph( + res, + raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree)); + + auto d_check_num_protected_edges = raft::make_device_scalar(res, true); + + uint32_t batch_size = + std::min(static_cast(graph_size), static_cast(256 * 1024)); + const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; + + const dim3 threads_merge(32, 1, 1); + const dim3 blocks_merge(batch_size, 1, 1); + const size_t merge_smem_size = (output_graph_degree + output_graph_degree) * sizeof(IdxT); + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + kern_merge_graph + <<>>( + d_output_graph.data_handle(), + d_rev_graph, + d_rev_graph_count, + 
static_cast(graph_size), + static_cast(output_graph_degree), + mst_graph_ptr, + static_cast(output_graph_degree), + mst_graph_num_edges_ptr, + batch_size, + i_batch, + guarantee_connectivity, + d_check_num_protected_edges.data_handle()); + } + + bool check_num_protected_edges = true; + raft::copy(&check_num_protected_edges, + d_check_num_protected_edges.data_handle(), + 1, + raft::resource::get_cuda_stream(res)); + + if (d_output_graph.allocated_memory()) { + raft::copy(output_graph_ptr, + d_output_graph.data_handle(), + graph_size * output_graph_degree, + raft::resource::get_cuda_stream(res)); + } + + raft::resource::sync_stream(res); + + const auto merge_graph_end = cur_time(); + RAFT_EXPECTS(check_num_protected_edges, + "Failed to merge the MST, pruned, and reverse edge graphs. " + "Some nodes have too " + "many MST optimization edges."); + + RAFT_LOG_DEBUG("# Time for merging graphs: %.1lf ms", + (merge_graph_end - merge_graph_start) * 1000.0); +} + +template +void merge_graph_cpu(IdxT* output_graph_ptr, + const IdxT* rev_graph_ptr, + const uint32_t* rev_graph_count_ptr, + const IdxT* mst_graph_ptr, + const uint32_t* mst_graph_num_edges_ptr, + uint64_t graph_size, + uint64_t output_graph_degree, + bool guarantee_connectivity) +{ + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/combine"); + + const double time_replace_start = cur_time(); + + bool check_num_protected_edges = true; +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + auto my_rev_graph = rev_graph_ptr + (output_graph_degree * i); + auto my_out_graph = output_graph_ptr + (output_graph_degree * i); + + std::vector temp_output_neighbor_list; + if (guarantee_connectivity) { + temp_output_neighbor_list.resize(output_graph_degree); + my_out_graph = temp_output_neighbor_list.data(); + const auto mst_graph_num_edges = mst_graph_num_edges_ptr[i]; + + for (uint32_t j = 0; j < mst_graph_num_edges; j++) { + my_out_graph[j] = mst_graph_ptr[i * output_graph_degree + j]; + } 
+ + for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; + (pruned_j < output_graph_degree) && (output_j < output_graph_degree); + pruned_j++) { + const auto v = output_graph_ptr[output_graph_degree * i + pruned_j]; + + bool dup = false; + for (uint32_t m = 0; m < output_j; m++) { + if (v == my_out_graph[m]) { + dup = true; + break; + } + } + + if (!dup) { + my_out_graph[output_j] = v; + output_j++; + } + } + } + + const auto num_protected_edges = + std::max(mst_graph_num_edges_ptr[i], output_graph_degree / 2); + if (num_protected_edges > output_graph_degree) { check_num_protected_edges = false; } + if (num_protected_edges == output_graph_degree) continue; + + auto kr = std::min(rev_graph_count_ptr[i], output_graph_degree); + while (kr) { + kr -= 1; + if (my_rev_graph[kr] < graph_size) { + uint64_t pos = pos_in_array(my_rev_graph[kr], my_out_graph, output_graph_degree); + if (pos < num_protected_edges) { continue; } + uint64_t num_shift = pos - num_protected_edges; + if (pos >= output_graph_degree) { + num_shift = output_graph_degree - num_protected_edges - 1; + } + shift_array(my_out_graph + num_protected_edges, num_shift); + my_out_graph[num_protected_edges] = my_rev_graph[kr]; + } + } + + if (guarantee_connectivity) { + for (uint32_t j = 0; j < output_graph_degree; j++) { + output_graph_ptr[(output_graph_degree * i) + j] = my_out_graph[j]; + } + } + } + RAFT_EXPECTS(check_num_protected_edges, + "Failed to merge the MST, pruned, and reverse edge graphs. 
Some nodes have too " + "many MST optimization edges."); + + const double time_replace_end = cur_time(); + RAFT_LOG_DEBUG("# Replacing edges time: %.1lf ms", + (time_replace_end - time_replace_start) * 1000.0); +} + +template +void make_reverse_graph_gpu(raft::resources const& res, + IdxT* d_rev_graph, + uint32_t* d_rev_graph_count, + raft::host_matrix_view new_graph) +{ + const uint64_t graph_size = new_graph.extent(0); + const uint64_t output_graph_degree = new_graph.extent(1); + const IdxT* output_graph_ptr = new_graph.data_handle(); + + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/reverse"); + + auto large_tmp_mr = raft::resource::get_large_workspace_resource(res); + auto dest_nodes = raft::make_host_vector(graph_size); + auto d_dest_nodes = + raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); + + RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph, + 0xff, + graph_size * output_graph_degree * sizeof(IdxT), + raft::resource::get_cuda_stream(res))); + + RAFT_CUDA_TRY(cudaMemsetAsync( + d_rev_graph_count, 0x00, graph_size * sizeof(uint32_t), raft::resource::get_cuda_stream(res))); + + for (uint64_t k = 0; k < output_graph_degree; k++) { +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + dest_nodes(i) = output_graph_ptr[k + (output_graph_degree * i)]; + } + raft::resource::sync_stream(res); + + raft::copy(d_dest_nodes.data_handle(), + dest_nodes.data_handle(), + graph_size, + raft::resource::get_cuda_stream(res)); + + dim3 threads(256, 1, 1); + dim3 blocks(1024, 1, 1); + kern_make_rev_graph<<>>( + d_dest_nodes.data_handle(), + d_rev_graph, + d_rev_graph_count, + static_cast(graph_size), + static_cast(output_graph_degree)); + RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %lu \r", k, output_graph_degree); + } + + raft::resource::sync_stream(res); + RAFT_LOG_DEBUG("\n"); +} } // namespace template k // specified number of edges are picked up for each node, starting with the edge with // the lowest 
number of 2-hop detours. // -template -void prune_graph(raft::resources const& res, - InputMatrixView knn_graph, - OutputMatrixView output_graph, - bool use_gpu) +template +void prune_graph_gpu(raft::resources const& res, + IdxT* knn_graph_ptr, + uint64_t graph_size, + uint64_t knn_graph_degree, + IdxT* output_graph_ptr, + uint64_t output_graph_degree) { - const uint64_t graph_size = output_graph.extent(0); - const uint64_t knn_graph_degree = knn_graph.extent(1); - const uint64_t output_graph_degree = output_graph.extent(1); - auto output_graph_ptr = output_graph.data_handle(); - - auto large_tmp_mr = raft::resource::get_large_workspace_resource(res); + auto default_ws_mr = raft::resource::get_workspace_resource(res); uint32_t batch_size = std::min(static_cast(graph_size), static_cast(256 * 1024)); const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - bool use_gpu_prune = use_gpu; - if (use_gpu_prune) { - try { - auto d_detour_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size, knn_graph_degree)); - auto d_num_no_detour_edges = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size)); - auto d_output_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size, output_graph_degree)); - auto d_input_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); - } catch (std::bad_alloc& e) { - RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU"); - use_gpu_prune = false; - } catch (raft::logic_error& e) { - RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU (logic error)"); - use_gpu_prune = false; - } - } - - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune"); - const double time_prune_start = cur_time(); + RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); - if (use_gpu_prune) { - RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); + constexpr int MAX_DEGREE = 1024; + if (knn_graph_degree > 
MAX_DEGREE) { + RAFT_FAIL( + "The degree of input knn graph is too large (%zu). " + "It must be equal to or smaller than %d.", + knn_graph_degree, + MAX_DEGREE); + } - constexpr int MAX_DEGREE = 1024; - if (knn_graph_degree > MAX_DEGREE) { - RAFT_FAIL( - "The degree of input knn graph is too large (%zu). " - "It must be equal to or smaller than %d.", + const double prune_start = cur_time(); + + uint64_t num_keep __attribute__((unused)) = 0; + uint64_t num_full __attribute__((unused)) = 0; + auto dev_stats = raft::make_device_vector(res, 2); + auto host_stats = raft::make_host_vector(2); + RAFT_CUDA_TRY(cudaMemsetAsync( + dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, raft::resource::get_cuda_stream(res))); + + device_matrix_view_from_host d_input_graph( + res, raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree)); + + auto d_detour_count = raft::make_device_mdarray( + res, default_ws_mr, raft::make_extents(batch_size, knn_graph_degree)); + auto d_num_no_detour_edges = raft::make_device_mdarray( + res, default_ws_mr, raft::make_extents(batch_size)); + auto d_output_graph = raft::make_device_mdarray( + res, default_ws_mr, raft::make_extents(batch_size, output_graph_degree)); + auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); + + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), + 0xff, + batch_size * knn_graph_degree * sizeof(uint8_t), + raft::resource::get_cuda_stream(res))); + + RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(), + 0x00, + batch_size * sizeof(uint32_t), + raft::resource::get_cuda_stream(res))); + + const dim3 threads_prune(32, 1, 1); + const dim3 blocks_prune(batch_size, 1, 1); + const size_t prune_smem_size = knn_graph_degree * sizeof(IdxT); + kern_prune + <<>>( + d_input_graph.data_handle(), + graph_size, knn_graph_degree, - MAX_DEGREE); - } + output_graph_degree, + batch_size, + i_batch, + d_detour_count.data_handle(), 
+ d_num_no_detour_edges.data_handle(), + dev_stats.data_handle()); + + const size_t select_smem_size = (knn_graph_degree) * (sizeof(uint16_t) + sizeof(IdxT)); + const dim3 threads_select(32, 1, 1); + const dim3 blocks_select(batch_size, 1, 1); + kern_select_smallest_detour_neighbors + <<>>( + d_input_graph.data_handle(), + graph_size, + knn_graph_degree, + output_graph_degree, + d_detour_count.data_handle(), + d_output_graph.data_handle(), + batch_size, + i_batch, + d_invalid_neighbor_list.data_handle()); + + size_t copy_size = + std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * + output_graph_degree; + raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, + d_output_graph.data_handle(), + copy_size, + raft::resource::get_cuda_stream(res)); - const double prune_start = cur_time(); + raft::resource::sync_stream(res); + RAFT_LOG_DEBUG( + "# Pruning kNN Graph on GPUs (%.1lf %%)\r", + (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); + } + raft::resource::sync_stream(res); + RAFT_LOG_DEBUG("\n"); - uint64_t num_keep __attribute__((unused)) = 0; - uint64_t num_full __attribute__((unused)) = 0; - auto dev_stats = raft::make_device_vector(res, 2); - auto host_stats = raft::make_host_vector(2); - RAFT_CUDA_TRY(cudaMemsetAsync( - dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, raft::resource::get_cuda_stream(res))); + uint32_t invalid_neighbor_list = 0; + raft::copy(&invalid_neighbor_list, + d_invalid_neighbor_list.data_handle(), + 1, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + RAFT_EXPECTS( + invalid_neighbor_list == 0, + "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " + "many invalid or duplicated neighbor nodes. 
This error can occur, for example, if too many " + "overflows occur during the norm computation between the dataset vectors."); - device_matrix_view_from_host d_input_graph( - res, - raft::make_host_matrix_view( - knn_graph.data_handle(), graph_size, knn_graph_degree)); + raft::copy( + host_stats.data_handle(), dev_stats.data_handle(), 2, raft::resource::get_cuda_stream(res)); + num_keep = host_stats.data_handle()[0]; + num_full = host_stats.data_handle()[1]; - auto d_detour_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size, knn_graph_degree)); - auto d_num_no_detour_edges = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size)); - auto d_output_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(batch_size, output_graph_degree)); - auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); - - for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), - 0xff, - batch_size * knn_graph_degree * sizeof(uint8_t), - raft::resource::get_cuda_stream(res))); - - RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(), - 0x00, - batch_size * sizeof(uint32_t), - raft::resource::get_cuda_stream(res))); - - const dim3 threads_prune(32, 1, 1); - const dim3 blocks_prune(batch_size, 1, 1); - const size_t prune_smem_size = knn_graph_degree * sizeof(IdxT); - kern_prune - <<>>( - d_input_graph.data_handle(), - graph_size, - knn_graph_degree, - output_graph_degree, - batch_size, - i_batch, - d_detour_count.data_handle(), - d_num_no_detour_edges.data_handle(), - dev_stats.data_handle()); - - const size_t select_smem_size = (knn_graph_degree) * (sizeof(uint16_t) + sizeof(IdxT)); - const dim3 threads_select(32, 1, 1); - const dim3 blocks_select(batch_size, 1, 1); - kern_select_smallest_detour_neighbors - <<>>( - d_input_graph.data_handle(), - graph_size, - knn_graph_degree, - output_graph_degree, - 
d_detour_count.data_handle(), - d_output_graph.data_handle(), - batch_size, - i_batch, - d_invalid_neighbor_list.data_handle()); - - size_t copy_size = - std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * - output_graph_degree; - raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, - d_output_graph.data_handle(), - copy_size, - raft::resource::get_cuda_stream(res)); + const double prune_end = cur_time(); + RAFT_LOG_DEBUG( + "# Time for pruning on GPU: %.1lf sec, " + "avg_no_detour_edges_per_node: %.2lf/%u, " + "nodes_with_no_detour_at_all_edges: %.1lf%%", + prune_end - prune_start, + (double)num_keep / graph_size, + output_graph_degree, + (double)num_full / graph_size * 100); +} - raft::resource::sync_stream(res); - RAFT_LOG_DEBUG( - "# Pruning kNN Graph on GPUs (%.1lf %%)\r", - (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); - } - raft::resource::sync_stream(res); - RAFT_LOG_DEBUG("\n"); +template +void prune_graph_cpu(IdxT* knn_graph_ptr, + uint64_t graph_size, + uint64_t knn_graph_degree, + IdxT* output_graph_ptr, + uint64_t output_graph_degree) +{ + auto detour_count = raft::make_host_matrix(graph_size, knn_graph_degree); - uint32_t invalid_neighbor_list = 0; - raft::copy(&invalid_neighbor_list, - d_invalid_neighbor_list.data_handle(), - 1, - raft::resource::get_cuda_stream(res)); - raft::resource::sync_stream(res); - RAFT_EXPECTS( - invalid_neighbor_list == 0, - "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " - "many invalid or duplicated neighbor nodes. 
This error can occur, for example, if too many " - "overflows occur during the norm computation between the dataset vectors."); - - raft::copy( - host_stats.data_handle(), dev_stats.data_handle(), 2, raft::resource::get_cuda_stream(res)); - num_keep = host_stats.data_handle()[0]; - num_full = host_stats.data_handle()[1]; - - const double prune_end = cur_time(); - RAFT_LOG_DEBUG( - "# Time for pruning on GPU: %.1lf sec, " - "avg_no_detour_edges_per_node: %.2lf/%u, " - "nodes_with_no_detour_at_all_edges: %.1lf%%", - prune_end - prune_start, - (double)num_keep / graph_size, - output_graph_degree, - (double)num_full / graph_size * 100); - } else { - auto detour_count = raft::make_host_matrix(graph_size, knn_graph_degree); + auto knn_graph_view = + raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree); - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune/2-hop-counting-by-CPU"); - const double time_2hop_count_start = cur_time(); + { + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/prune/2-hop-counting-by-CPU"); + const double time_2hop_count_start = cur_time(); - auto knn_graph_view = raft::make_host_matrix_view( - knn_graph.data_handle(), knn_graph.extent(0), knn_graph.extent(1)); - count_2hop_detours(knn_graph_view, detour_count.view()); + count_2hop_detours(knn_graph_view, detour_count.view()); - const double time_2hop_count_end = cur_time(); - RAFT_LOG_DEBUG("# Time for 2-hop detour counting on CPU: %.1lf sec", - time_2hop_count_end - time_2hop_count_start); - } - bool invalid_neighbor_list = false; + const double time_2hop_count_end = cur_time(); + RAFT_LOG_DEBUG("# Time for 2-hop detour counting on CPU: %.1lf sec", + time_2hop_count_end - time_2hop_count_start); + } + bool invalid_neighbor_list = false; #pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - uint64_t pk = 0; - uint32_t num_detour = 0; - for (uint32_t l = 0; l < knn_graph_degree && pk < output_graph_degree; l++) { - 
uint32_t next_num_detour = std::numeric_limits::max(); - for (uint64_t k = 0; k < knn_graph_degree; k++) { - const auto num_detour_k = detour_count(i, k); - if (num_detour_k > num_detour) { - next_num_detour = std::min(static_cast(num_detour_k), next_num_detour); - } + for (uint64_t i = 0; i < graph_size; i++) { + uint64_t pk = 0; + uint32_t num_detour = 0; + for (uint32_t l = 0; l < knn_graph_degree && pk < output_graph_degree; l++) { + uint32_t next_num_detour = std::numeric_limits::max(); + for (uint64_t k = 0; k < knn_graph_degree; k++) { + const auto num_detour_k = detour_count(i, k); + if (num_detour_k > num_detour) { + next_num_detour = std::min(static_cast(num_detour_k), next_num_detour); + } - if (num_detour_k != num_detour) { continue; } + if (num_detour_k != num_detour) { continue; } - const auto candidate_node = knn_graph(i, k); - bool dup = false; - for (uint32_t dk = 0; dk < pk; dk++) { - if (candidate_node == output_graph_ptr[i * output_graph_degree + dk]) { - dup = true; - break; - } - } - if (!dup && candidate_node < graph_size) { - output_graph_ptr[i * output_graph_degree + pk] = candidate_node; - pk += 1; + const auto candidate_node = knn_graph_view(i, k); + bool dup = false; + for (uint32_t dk = 0; dk < pk; dk++) { + if (candidate_node == output_graph_ptr[i * output_graph_degree + dk]) { + dup = true; + break; } - if (pk >= output_graph_degree) break; + } + if (!dup && candidate_node < graph_size) { + output_graph_ptr[i * output_graph_degree + pk] = candidate_node; + pk += 1; } if (pk >= output_graph_degree) break; - - if (next_num_detour == std::numeric_limits::max()) { break; } - num_detour = next_num_detour; - } - if (pk != output_graph_degree) { - RAFT_LOG_DEBUG( - "Couldn't find the output_graph_degree (%lu) smallest detourable count nodes for " - "node %lu in the rank-based node reranking process", - output_graph_degree, - i); - invalid_neighbor_list = true; } + if (pk >= output_graph_degree) break; + + if (next_num_detour == 
std::numeric_limits::max()) { break; } + num_detour = next_num_detour; + } + if (pk != output_graph_degree) { + RAFT_LOG_DEBUG( + "Couldn't find the output_graph_degree (%lu) smallest detourable count nodes for " + "node %lu in the rank-based node reranking process", + output_graph_degree, + i); + invalid_neighbor_list = true; } - RAFT_EXPECTS( - !invalid_neighbor_list, - "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " - "many invalid or duplicated neighbor nodes. This error can occur, for example, if too many " - "overflows occur during the norm computation between the dataset vectors."); } + RAFT_EXPECTS( + !invalid_neighbor_list, + "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " + "many invalid or duplicated neighbor nodes. This error can occur, for example, if too many " + "overflows occur during the norm computation between the dataset vectors."); +} - const double time_prune_end = cur_time(); - RAFT_LOG_DEBUG("# Pruning time: %.1lf ms", (time_prune_end - time_prune_start) * 1000.0); +template +bool is_gpu_accessible(T* ptr) +{ + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); + return attr.devicePointer != nullptr; } // TODO allow pinned input for both knn_graph and new_graph @@ -1600,7 +1893,7 @@ void optimize( // large temporary memory for large arrays, e.g. everything >= O(graph_size) auto large_tmp_mr = raft::resource::get_large_workspace_resource(res); // temporary memory for small arrays, e.g. 
everything <= O(batchsize * graph_degree) - // auto tmp_mr = raft::resource::get_tmp_workspace_resource(res); + auto default_ws_mr = raft::resource::get_workspace_resource(res); RAFT_EXPECTS(knn_graph.extent(0) == new_graph.extent(0), "Each input array is expected to have the same number of rows"); @@ -1611,409 +1904,174 @@ void optimize( const uint64_t output_graph_degree = new_graph.extent(1); const uint64_t graph_size = new_graph.extent(0); // auto input_graph_ptr = knn_graph.data_handle(); - auto output_graph_ptr = new_graph.data_handle(); raft::common::nvtx::range fun_scope( "cagra::graph::optimize(%zu, %zu, %u)", graph_size, knn_graph_degree, output_graph_degree); - // MST optimization - auto mst_graph = raft::make_pinned_matrix(res, 0, 0); - auto mst_graph_num_edges = raft::make_pinned_vector(res, graph_size); - auto mst_graph_num_edges_ptr = mst_graph_num_edges.data_handle(); + // check if input and output are both device accessible + // in this case we assume data to be ONLY device accessible and not host accessible + // furthermore we ensure all large allocations to go to the large workspace resource + // and all small allocations to go to the default workspace resource + bool inout_device_accessible = false; + { + bool input_device_accessible = is_gpu_accessible(knn_graph.data_handle()); + bool output_device_accessible = is_gpu_accessible(new_graph.data_handle()); + RAFT_EXPECTS(input_device_accessible == output_device_accessible, + "Input and output must be either both device accessible or both host accessible"); + inout_device_accessible = input_device_accessible && output_device_accessible; + } + // MST optimization + // currently, only using GPU path for MST optimization + auto p_mst_graph = raft::make_pinned_matrix(res, 0, 0); + auto p_mst_graph_num_edges = raft::make_pinned_vector(res, graph_size); + auto p_mst_graph_num_edges_ptr = p_mst_graph_num_edges.data_handle(); #pragma omp parallel for for (uint64_t i = 0; i < graph_size; i++) { - 
mst_graph_num_edges_ptr[i] = 0; + p_mst_graph_num_edges_ptr[i] = 0; } if (guarantee_connectivity) { raft::common::nvtx::range block_scope( "cagra::graph::optimize/check_connectivity"); - mst_graph = raft::make_pinned_matrix( + p_mst_graph = raft::make_pinned_matrix( res, graph_size, output_graph_degree); RAFT_LOG_INFO("MST optimization is used to guarantee graph connectivity."); - mst_optimization(res, knn_graph, mst_graph.view(), mst_graph_num_edges.view(), use_gpu); + mst_optimization( + res, knn_graph, p_mst_graph.view(), p_mst_graph_num_edges.view(), use_gpu); for (uint64_t i = 0; i < graph_size; i++) { if (i < 8 || i >= graph_size - 8) { - RAFT_LOG_DEBUG("# mst_graph_num_edges_ptr[%lu]: %u\n", i, mst_graph_num_edges_ptr[i]); + RAFT_LOG_DEBUG("# p_mst_graph_num_edges_ptr[%lu]: %u\n", i, p_mst_graph_num_edges_ptr[i]); } } } - prune_graph(res, knn_graph, new_graph, use_gpu); - - auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); - auto rev_graph_count = raft::make_host_vector(graph_size); - - bool _use_gpu_rev_graph = use_gpu; - // TODO: should we use pinned memory if we have issues fitting on GPU? 
- if (_use_gpu_rev_graph) { + // prune graph -- will use GPU path if possible, otherwise CPU path + // we only need to check in case input is not already device accessible + bool use_gpu_prune = use_gpu; + if (!inout_device_accessible) { try { - auto d_rev_graph_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size)); - auto d_dest_nodes = - raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); - auto d_rev_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); - auto d_output_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); + auto d_input_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); } catch (std::bad_alloc& e) { - RAFT_LOG_DEBUG("Insufficient memory for reverse graph on GPU"); - _use_gpu_rev_graph = false; + RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU"); + use_gpu_prune = false; } catch (raft::logic_error& e) { - RAFT_LOG_DEBUG("Insufficient memory for reverse graph on GPU (logic error)"); - _use_gpu_rev_graph = false; + RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU (logic error)"); + use_gpu_prune = false; } } - - const double time_make_start = cur_time(); - if (_use_gpu_rev_graph) { - // - // Make reverse graph on GPU - // - auto d_rev_graph_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size)); - - device_matrix_view_from_host d_rev_graph(res, rev_graph.view()); - device_matrix_view_from_host d_output_graph( + if (use_gpu_prune) { + // should be noop in case input is already device accessible + device_matrix_view_from_host d_input_graph( res, raft::make_host_matrix_view( - output_graph_ptr, graph_size, output_graph_degree)); - - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/reverse"); - auto dest_nodes = raft::make_host_vector(graph_size); - auto 
d_dest_nodes = - raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); - - RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph.data_handle(), - 0xff, - graph_size * output_graph_degree * sizeof(IdxT), - raft::resource::get_cuda_stream(res))); - - RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph_count.data_handle(), - 0x00, - graph_size * sizeof(uint32_t), - raft::resource::get_cuda_stream(res))); - - for (uint64_t k = 0; k < output_graph_degree; k++) { -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - // dest_nodes.data_handle()[i] = output_graph_ptr[k + (output_graph_degree * i)]; - dest_nodes(i) = output_graph_ptr[k + (output_graph_degree * i)]; - } - raft::resource::sync_stream(res); - - raft::copy(d_dest_nodes.data_handle(), - dest_nodes.data_handle(), - graph_size, - raft::resource::get_cuda_stream(res)); - - dim3 threads(256, 1, 1); - dim3 blocks(1024, 1, 1); - kern_make_rev_graph<<>>( - d_dest_nodes.data_handle(), - d_rev_graph.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - output_graph_degree); - RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); - } - - raft::resource::sync_stream(res); - RAFT_LOG_DEBUG("\n"); - - if (d_rev_graph.allocated_memory()) { - raft::copy(rev_graph.data_handle(), - d_rev_graph.data_handle(), - graph_size * output_graph_degree, - raft::resource::get_cuda_stream(res)); - } - raft::copy(rev_graph_count.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - raft::resource::get_cuda_stream(res)); - - const double time_make_end = cur_time(); - RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", - (time_make_end - time_make_start) * 1000.0); - } - - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/combine"); - - // Merging the prunned graph and the reverse graph - const double merge_graph_start = cur_time(); - - // Create a boolean variable on the GPU using RAFT device allocator - auto d_check_num_protected_edges = 
raft::make_device_scalar(res, true); - - uint32_t batch_size = - std::min(static_cast(graph_size), static_cast(256 * 1024)); - const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - - const dim3 threads_merge(32, 1, 1); - const dim3 blocks_merge(batch_size, 1, 1); - const size_t merge_smem_size = (output_graph_degree + output_graph_degree) * sizeof(IdxT); - for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - kern_merge_graph - <<>>( - d_output_graph.data_handle(), - d_rev_graph.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - output_graph_degree, - mst_graph.data_handle(), - output_graph_degree, - mst_graph_num_edges_ptr, - batch_size, - i_batch, - guarantee_connectivity, - d_check_num_protected_edges.data_handle()); - } - - bool check_num_protected_edges = true; - raft::copy(&check_num_protected_edges, - d_check_num_protected_edges.data_handle(), - 1, - raft::resource::get_cuda_stream(res)); - - if (d_output_graph.allocated_memory()) { - raft::copy(output_graph_ptr, - d_output_graph.data_handle(), - graph_size * output_graph_degree, - raft::resource::get_cuda_stream(res)); - } - - raft::resource::sync_stream(res); + knn_graph.data_handle(), graph_size, knn_graph_degree)); - const auto merge_graph_end = cur_time(); - RAFT_EXPECTS(check_num_protected_edges, - "Failed to merge the MST, pruned, and reverse edge graphs. 
" - "Some nodes have too " - "many MST optimization edges."); + prune_graph_gpu(res, + d_input_graph.data_handle(), + graph_size, + knn_graph_degree, + new_graph.data_handle(), + output_graph_degree); - RAFT_LOG_DEBUG("# Time for merging graphs: %.1lf ms", - (merge_graph_end - merge_graph_start) * 1000.0); - } } else { - { - // Make reverse graph on CPU - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/reverse"); - - auto rev_graph_ptr = rev_graph.data_handle(); - auto rev_graph_count_ptr = rev_graph_count.data_handle(); - -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - rev_graph_count_ptr[i] = 0; - } - - for (uint32_t k = 0; k < output_graph_degree; k++) { -#pragma omp parallel for - for (uint64_t src_id = 0; src_id < graph_size; src_id++) { - const IdxT dest_id = - output_graph_ptr[k + (static_cast(output_graph_degree) * src_id)]; - if (dest_id >= graph_size) continue; - uint32_t pos; -#pragma omp atomic capture - pos = rev_graph_count_ptr[dest_id]++; - if (pos < output_graph_degree) { - rev_graph_ptr[(static_cast(output_graph_degree) * dest_id) + pos] = - static_cast(src_id); - } - } - } + prune_graph_cpu(knn_graph.data_handle(), + graph_size, + knn_graph_degree, + new_graph.data_handle(), + output_graph_degree); + } - const double time_make_end = cur_time(); - RAFT_LOG_DEBUG("# Making reverse graph time (CPU): %.1lf ms", - (time_make_end - time_make_start) * 1000.0); - } + // reverse graph creation will always use the GPU + auto d_rev_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/combine"); - // - // Create search graphs from MST and pruned and reverse graphs - // - const double time_replace_start = cur_time(); + // This should use the default workspace resource for random access / atomics + auto d_rev_graph_count = raft::make_device_mdarray( + res, default_ws_mr, 
raft::make_extents(graph_size)); - bool check_num_protected_edges = true; -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - auto my_rev_graph = rev_graph.data_handle() + (output_graph_degree * i); - auto my_out_graph = output_graph_ptr + (output_graph_degree * i); - - // If guarantee_connectivity == true, use a temporal list to merge the neighbor lists of the - // graphs. - std::vector temp_output_neighbor_list; - if (guarantee_connectivity) { - temp_output_neighbor_list.resize(output_graph_degree); - my_out_graph = temp_output_neighbor_list.data(); - const auto mst_graph_num_edges = mst_graph_num_edges_ptr[i]; - - // Set MST graph edges - for (uint32_t j = 0; j < mst_graph_num_edges; j++) { - my_out_graph[j] = mst_graph(i, j); - } - - // Set pruned graph edges - for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; - (pruned_j < output_graph_degree) && (output_j < output_graph_degree); - pruned_j++) { - const auto v = output_graph_ptr[output_graph_degree * i + pruned_j]; - - // duplication check - bool dup = false; - for (uint32_t m = 0; m < output_j; m++) { - if (v == my_out_graph[m]) { - dup = true; - break; - } - } + const double time_make_start = cur_time(); - if (!dup) { - my_out_graph[output_j] = v; - output_j++; - } - } - } + make_reverse_graph_gpu( + res, d_rev_graph.data_handle(), d_rev_graph_count.data_handle(), new_graph); - const auto num_protected_edges = - std::max(mst_graph_num_edges_ptr[i], output_graph_degree / 2); - if (num_protected_edges > output_graph_degree) { check_num_protected_edges = false; } - if (num_protected_edges == output_graph_degree) continue; - - // Replace some edges of the output graph with edges of the reverse graph. 
- auto kr = std::min(rev_graph_count.data_handle()[i], output_graph_degree); - while (kr) { - kr -= 1; - if (my_rev_graph[kr] < graph_size) { - uint64_t pos = pos_in_array(my_rev_graph[kr], my_out_graph, output_graph_degree); - if (pos < num_protected_edges) { continue; } - uint64_t num_shift = pos - num_protected_edges; - if (pos >= output_graph_degree) { - num_shift = output_graph_degree - num_protected_edges - 1; - } - shift_array(my_out_graph + num_protected_edges, num_shift); - my_out_graph[num_protected_edges] = my_rev_graph[kr]; - } - } + const double time_make_end = cur_time(); + RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", + (time_make_end - time_make_start) * 1000.0); - // If guarantee_connectivity == true, move the output neighbor list from the temporal list - // to the output list. If false, the copy is not needed because my_out_graph is a pointer to - // the output buffer. - if (guarantee_connectivity) { - for (uint32_t j = 0; j < output_graph_degree; j++) { - output_graph_ptr[(output_graph_degree * i) + j] = my_out_graph[j]; - } - } - } - RAFT_EXPECTS(check_num_protected_edges, - "Failed to merge the MST, pruned, and reverse edge graphs. 
Some nodes have too " - "many MST optimization edges."); - - const double time_replace_end = cur_time(); - RAFT_LOG_DEBUG("# Replacing edges time: %.1lf ms", - (time_replace_end - time_replace_start) * 1000.0); + // merge graph -- will use GPU path if possible, otherwise CPU path + // we only need to check in case output is not already device accessible + bool use_gpu_merge = use_gpu; + if (!inout_device_accessible) { + try { + auto d_new_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); + } catch (std::bad_alloc& e) { + RAFT_LOG_DEBUG("Insufficient memory for merging on GPU"); + use_gpu_merge = false; + } catch (raft::logic_error& e) { + RAFT_LOG_DEBUG("Insufficient memory for merging on GPU (logic error)"); + use_gpu_merge = false; } } - // Check stats - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/stats"); - /* stats */ - uint64_t num_replaced_edges = 0; -#pragma omp parallel for reduction(+ : num_replaced_edges) - for (uint64_t i = 0; i < graph_size; i++) { - for (uint64_t k = 0; k < output_graph_degree; k++) { - const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; - const uint64_t pos = - pos_in_array(j, output_graph_ptr + (output_graph_degree * i), output_graph_degree); - if (pos == output_graph_degree) { num_replaced_edges += 1; } - } + if (use_gpu_merge) { + // should be noop in case output is already device accessible + device_matrix_view_from_host d_new_graph( + res, + raft::make_host_matrix_view( + new_graph.data_handle(), graph_size, output_graph_degree)); + + merge_graph_gpu(res, + d_new_graph.data_handle(), + d_rev_graph.data_handle(), + d_rev_graph_count.data_handle(), + p_mst_graph.data_handle(), + p_mst_graph_num_edges.data_handle(), + graph_size, + output_graph_degree, + guarantee_connectivity); + + if (d_new_graph.allocated_memory()) { + raft::copy(new_graph.data_handle(), + d_new_graph.data_handle(), + graph_size * output_graph_degree, + 
raft::resource::get_cuda_stream(res)); } - RAFT_LOG_DEBUG("# Average number of replaced edges per node: %.2f", - (double)num_replaced_edges / graph_size); - } + } else { + auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); + auto rev_graph_count = raft::make_host_vector(graph_size); + auto mst_graph = raft::make_host_matrix(0, 0); + raft::copy(rev_graph.data_handle(), + d_rev_graph.data_handle(), + graph_size * output_graph_degree, + raft::resource::get_cuda_stream(res)); + raft::copy(rev_graph_count.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); - // Check number of incoming edges - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/check_edges"); - auto in_edge_count = raft::make_host_vector(graph_size); - auto in_edge_count_ptr = in_edge_count.data_handle(); -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - in_edge_count_ptr[i] = 0; - } -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - for (uint64_t k = 0; k < output_graph_degree; k++) { - const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; - if (j >= graph_size) continue; -#pragma omp atomic - in_edge_count_ptr[j] += 1; - } - } - auto hist = raft::make_host_vector(output_graph_degree); - auto hist_ptr = hist.data_handle(); - for (uint64_t k = 0; k < output_graph_degree; k++) { - hist_ptr[k] = 0; - } -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - uint32_t count = in_edge_count_ptr[i]; - if (count >= output_graph_degree) continue; -#pragma omp atomic - hist_ptr[count] += 1; - } - RAFT_LOG_DEBUG("# Histogram for number of incoming edges\n"); - uint32_t sum_hist = 0; - for (uint64_t k = 0; k < output_graph_degree; k++) { - sum_hist += hist_ptr[k]; - RAFT_LOG_DEBUG("# %3lu, %8u, %lf, (%8u, %lf)\n", - k, - hist_ptr[k], - (double)hist_ptr[k] / graph_size, - sum_hist, - (double)sum_hist / 
graph_size); - } + merge_graph_cpu(new_graph.data_handle(), + rev_graph.data_handle(), + rev_graph_count.data_handle(), + p_mst_graph.data_handle(), + p_mst_graph_num_edges_ptr, + graph_size, + output_graph_degree, + guarantee_connectivity); } - // Check duplication and out-of-range indices - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/check_duplicates"); - uint64_t num_dup = 0; - uint64_t num_oor = 0; -#pragma omp parallel for reduction(+ : num_dup) reduction(+ : num_oor) - for (uint64_t i = 0; i < graph_size; i++) { - auto my_out_graph = output_graph_ptr + (output_graph_degree * i); - for (uint32_t j = 0; j < output_graph_degree; j++) { - const auto neighbor_a = my_out_graph[j]; + if (!inout_device_accessible) { + // following checks require host access + log_replaced_edges_stats(new_graph.data_handle(), graph_size, output_graph_degree); - // Check oor - if (neighbor_a > graph_size) { - num_oor++; - continue; - } + log_incoming_edges_histogram(new_graph.data_handle(), graph_size, output_graph_degree); - // Check duplication - for (uint32_t k = j + 1; k < output_graph_degree; k++) { - const auto neighbor_b = my_out_graph[k]; - if (neighbor_a == neighbor_b) { num_dup++; } - } - } - } - RAFT_EXPECTS( - num_dup == 0, "%lu duplicated node(s) are found in the generated CAGRA graph", num_dup); - RAFT_EXPECTS(num_oor == 0, - "%lu out-of-range index node(s) are found in the generated CAGRA graph", - num_oor); + check_duplicates_and_out_of_range( + new_graph.data_handle(), graph_size, output_graph_degree); + } else { + RAFT_LOG_DEBUG("Output graph is on GPU, skipping checks"); } } From 5e9ebc53950e472c8ee0035f280905dc5b1984b5 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 2 Mar 2026 17:34:17 +0000 Subject: [PATCH 08/22] enable both host/device inout graphs for optimize --- .../neighbors/detail/cagra/cagra_build.cuh | 23 ++--- cpp/src/neighbors/detail/cagra/graph_core.cuh | 97 +++++++++++-------- cpp/src/neighbors/detail/cagra/utils.hpp | 
18 +++- 3 files changed, 86 insertions(+), 52 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index a1c16250c5..009362aa96 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -822,8 +822,6 @@ inline std::pair optimize_workspace_size(size_t n_rows, size_t index_size, bool mst_optimize = false) { - // TODO: MODIFY!! - // MST optimization memory (host only) size_t mst_host = n_rows * index_size; // mst_graph_num_edges if (mst_optimize) { @@ -835,27 +833,26 @@ inline std::pair optimize_workspace_size(size_t n_rows, // Prune stage memory // We neglect 8 bytes (both on host and device) for stats - size_t prune_host = n_rows * intermediate_degree * sizeof(uint8_t); // detour count + size_t batch_size = std::min(static_cast(256 * 1024), n_rows); - size_t prune_dev = n_rows * intermediate_degree * 1; // detour count (uint8_t) - prune_dev += n_rows * sizeof(uint32_t); // d_num_detour_edges - prune_dev += n_rows * intermediate_degree * index_size; // d_input_graph + size_t prune_dev = batch_size * intermediate_degree * 1; // detour count (uint8_t) + prune_dev += batch_size * sizeof(uint32_t); // d_num_detour_edges + prune_dev += n_rows * intermediate_degree * index_size; // d_input_graph // Reverse graph stage memory - size_t rev_host = n_rows * graph_degree * index_size; // rev_graph - rev_host += n_rows * sizeof(uint32_t); // rev_graph_count - rev_host += n_rows * index_size; // dest_nodes - size_t rev_dev = n_rows * graph_degree * index_size; // d_rev_graph rev_dev += n_rows * sizeof(uint32_t); // d_rev_graph_count rev_dev += n_rows * sizeof(uint32_t); // d_dest_nodes - // Memory for merging graphs (host only) + // Memory for merging graphs (host only optional) size_t combine_host = n_rows * sizeof(uint32_t) + graph_degree * sizeof(uint32_t); // in_edge_count + hist - size_t total_host = mst_host + std::max({prune_host, rev_host, 
combine_host}); - size_t total_dev = std::max(prune_dev, rev_dev); + // additional memory for combine stage on device + size_t combine_dev = n_rows * graph_degree * index_size; // d_output_graph + + size_t total_host = mst_host + combine_host; + size_t total_dev = std::max(prune_dev, rev_dev + combine_dev); return std::make_pair(total_host, total_dev); } diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 713b03ca20..96110c9613 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -246,6 +246,26 @@ __global__ void kern_make_rev_graph(const IdxT* const dest_nodes, // [graph_ } } +template +__global__ void kern_make_rev_graph_k(const IdxT* const dest_nodes, // [graph_size] + IdxT* const rev_graph, // [size, degree] + uint32_t* const rev_graph_count, // [graph_size] + const uint32_t graph_size, + const uint32_t degree, + uint64_t k) +{ + const uint64_t tid = threadIdx.x + (blockDim.x * blockIdx.x); + const uint64_t tnum = blockDim.x * gridDim.x; + + for (uint64_t src_id = tid; src_id < graph_size; src_id += tnum) { + IdxT dest_id = dest_nodes[k + (degree * src_id)]; + if (dest_id >= graph_size) continue; + + const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); + if (pos < degree) { rev_graph[(degree * dest_id) + pos] = static_cast(src_id); } + } +} + // Based on the detour count, select the smallest detour count and its index // (Pruning Update Kernel) template @@ -932,11 +952,11 @@ void merge_graph_cpu(IdxT* output_graph_ptr, (time_replace_end - time_replace_start) * 1000.0); } -template +template void make_reverse_graph_gpu(raft::resources const& res, IdxT* d_rev_graph, uint32_t* d_rev_graph_count, - raft::host_matrix_view new_graph) + InOutMatrixView new_graph) { const uint64_t graph_size = new_graph.extent(0); const uint64_t output_graph_degree = new_graph.extent(1); @@ -958,26 +978,38 @@ void make_reverse_graph_gpu(raft::resources 
const& res, RAFT_CUDA_TRY(cudaMemsetAsync( d_rev_graph_count, 0x00, graph_size * sizeof(uint32_t), raft::resource::get_cuda_stream(res))); + bool output_graph_device_accessible = is_ptr_device_accessible(output_graph_ptr); + dim3 threads(256, 1, 1); + dim3 blocks(1024, 1, 1); + for (uint64_t k = 0; k < output_graph_degree; k++) { + if (output_graph_device_accessible) { + kern_make_rev_graph_k<<>>( + output_graph_ptr, + d_rev_graph, + d_rev_graph_count, + static_cast(graph_size), + static_cast(output_graph_degree), + k); + } else { #pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - dest_nodes(i) = output_graph_ptr[k + (output_graph_degree * i)]; - } - raft::resource::sync_stream(res); + for (uint64_t i = 0; i < graph_size; i++) { + dest_nodes(i) = output_graph_ptr[k + (output_graph_degree * i)]; + } + raft::resource::sync_stream(res); - raft::copy(d_dest_nodes.data_handle(), - dest_nodes.data_handle(), - graph_size, - raft::resource::get_cuda_stream(res)); + raft::copy(d_dest_nodes.data_handle(), + dest_nodes.data_handle(), + graph_size, + raft::resource::get_cuda_stream(res)); - dim3 threads(256, 1, 1); - dim3 blocks(1024, 1, 1); - kern_make_rev_graph<<>>( - d_dest_nodes.data_handle(), - d_rev_graph, - d_rev_graph_count, - static_cast(graph_size), - static_cast(output_graph_degree)); + kern_make_rev_graph<<>>( + d_dest_nodes.data_handle(), + d_rev_graph, + d_rev_graph_count, + static_cast(graph_size), + static_cast(output_graph_degree)); + } RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %lu \r", k, output_graph_degree); } @@ -1868,24 +1900,13 @@ void prune_graph_cpu(IdxT* knn_graph_ptr, "overflows occur during the norm computation between the dataset vectors."); } -template -bool is_gpu_accessible(T* ptr) -{ - cudaPointerAttributes attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); - return attr.devicePointer != nullptr; -} - // TODO allow pinned input for both knn_graph and new_graph -template , raft::memory_type::host>> 
-void optimize( - raft::resources const& res, - raft::mdspan, raft::row_major, g_accessor> knn_graph, - raft::host_matrix_view new_graph, - const bool guarantee_connectivity = true, - const bool use_gpu = true) +template +void optimize(raft::resources const& res, + InOutMatrixView knn_graph, + InOutMatrixView new_graph, + const bool guarantee_connectivity = true, + const bool use_gpu = true) { RAFT_LOG_DEBUG( "# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1)); @@ -1913,8 +1934,8 @@ void optimize( // and all small allocations to go to the default workspace resource bool inout_device_accessible = false; { - bool input_device_accessible = is_gpu_accessible(knn_graph.data_handle()); - bool output_device_accessible = is_gpu_accessible(new_graph.data_handle()); + bool input_device_accessible = is_ptr_device_accessible(knn_graph.data_handle()); + bool output_device_accessible = is_ptr_device_accessible(new_graph.data_handle()); RAFT_EXPECTS(input_device_accessible == output_device_accessible, "Input and output must be either both device accessible or both host accessible"); inout_device_accessible = input_device_accessible && output_device_accessible; @@ -2062,7 +2083,7 @@ void optimize( guarantee_connectivity); } - if (!inout_device_accessible) { + if (is_ptr_host_accessible(new_graph.data_handle())) { // following checks require host access log_replaced_edges_stats(new_graph.data_handle(), graph_size, output_graph_degree); diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 30c7287430..7889d6d9a9 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -152,6 +152,22 @@ struct gen_index_msb_1_mask { }; } // namespace utils +template +bool is_ptr_device_accessible(T* ptr) +{ + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); + return attr.devicePointer != nullptr; +} + +template +bool is_ptr_host_accessible(T* ptr) +{ + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); + return attr.hostPointer != nullptr; +} + /** * Utility to sync memory from a host_matrix_view to a device_matrix_view * From 40977e2e456f2fd9ee32413be3590acfe2e7bdd4 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 2 Mar 2026 23:35:32 +0000 Subject: [PATCH 09/22] smaller fixes --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 61 +++++++++++-------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 6c0d6c747c..f77f6367e5 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -801,8 +801,8 @@ void check_duplicates_and_out_of_range(const IdxT* output_graph_ptr, template void merge_graph_gpu(raft::resources const& res, IdxT* output_graph_ptr, - const IdxT* d_rev_graph, - uint32_t* d_rev_graph_count, + const IdxT* d_rev_graph_ptr, + uint32_t* d_rev_graph_count_ptr, const IdxT* mst_graph_ptr, const uint32_t* mst_graph_num_edges_ptr, uint64_t graph_size, @@ -831,8 +831,8 @@ void merge_graph_gpu(raft::resources const& res, kern_merge_graph <<>>( d_output_graph.data_handle(), - d_rev_graph, - d_rev_graph_count, + d_rev_graph_ptr, + d_rev_graph_count_ptr, static_cast(graph_size), static_cast(output_graph_degree), mst_graph_ptr, @@ -955,8 +955,8 @@ void merge_graph_cpu(IdxT* output_graph_ptr, template void make_reverse_graph_gpu(raft::resources const& res, - IdxT* d_rev_graph, - uint32_t* d_rev_graph_count, + IdxT* d_rev_graph_ptr, + uint32_t* 
d_rev_graph_count_ptr, InOutMatrixView new_graph) { const uint64_t graph_size = new_graph.extent(0); @@ -966,18 +966,19 @@ void make_reverse_graph_gpu(raft::resources const& res, raft::common::nvtx::range block_scope( "cagra::graph::optimize/reverse"); - auto large_tmp_mr = raft::resource::get_large_workspace_resource(res); auto dest_nodes = raft::make_host_vector(graph_size); - auto d_dest_nodes = - raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); + auto d_dest_nodes = raft::make_device_mdarray( + res, raft::resource::get_workspace_resource(res), raft::make_extents(graph_size)); raft::matrix::fill( res, - raft::make_device_vector_view(d_rev_graph, graph_size * output_graph_degree), + raft::make_device_vector_view(d_rev_graph_ptr, graph_size * output_graph_degree), IdxT(-1)); raft::matrix::fill( - res, raft::make_device_vector_view(d_rev_graph_count, graph_size), uint32_t(0)); + res, + raft::make_device_vector_view(d_rev_graph_count_ptr, graph_size), + uint32_t(0)); bool output_graph_device_accessible = is_ptr_device_accessible(output_graph_ptr); dim3 threads(256, 1, 1); @@ -987,8 +988,8 @@ void make_reverse_graph_gpu(raft::resources const& res, if (output_graph_device_accessible) { kern_make_rev_graph_k<<>>( output_graph_ptr, - d_rev_graph, - d_rev_graph_count, + d_rev_graph_ptr, + d_rev_graph_count_ptr, static_cast(graph_size), static_cast(output_graph_degree), k); @@ -1003,8 +1004,8 @@ void make_reverse_graph_gpu(raft::resources const& res, kern_make_rev_graph<<>>( d_dest_nodes.data_handle(), - d_rev_graph, - d_rev_graph_count, + d_rev_graph_ptr, + d_rev_graph_count_ptr, static_cast(graph_size), static_cast(output_graph_degree)); } @@ -1679,6 +1680,8 @@ void prune_graph_gpu(raft::resources const& res, IdxT* output_graph_ptr, uint64_t output_graph_degree) { + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/prune"); auto default_ws_mr = raft::resource::get_workspace_resource(res); uint32_t batch_size = @@ -1715,6 
+1718,8 @@ void prune_graph_gpu(raft::resources const& res, res, default_ws_mr, raft::make_extents(batch_size, output_graph_degree)); auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); + bool output_device_accessible = is_ptr_device_accessible(output_graph_ptr); + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { raft::matrix::fill(res, d_detour_count.view(), uint8_t(0xff)); raft::matrix::fill(res, d_num_no_detour_edges.view(), uint32_t(0)); @@ -1744,18 +1749,21 @@ void prune_graph_gpu(raft::resources const& res, knn_graph_degree, output_graph_degree, d_detour_count.data_handle(), - d_output_graph.data_handle(), + output_device_accessible ? d_output_graph.data_handle() + : output_graph_ptr + i_batch * batch_size * output_graph_degree, batch_size, i_batch, d_invalid_neighbor_list.data_handle()); - size_t copy_size = - std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * - output_graph_degree; - raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, - d_output_graph.data_handle(), - copy_size, - raft::resource::get_cuda_stream(res)); + if (!output_device_accessible) { + size_t copy_size = + std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * + output_graph_degree; + raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, + d_output_graph.data_handle(), + copy_size, + raft::resource::get_cuda_stream(res)); + } raft::resource::sync_stream(res); RAFT_LOG_DEBUG( @@ -1799,14 +1807,14 @@ void prune_graph_cpu(IdxT* knn_graph_ptr, IdxT* output_graph_ptr, uint64_t output_graph_degree) { + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/prune"); auto detour_count = raft::make_host_matrix(graph_size, knn_graph_degree); auto knn_graph_view = raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree); { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune/2-hop-counting-by-CPU"); const double time_2hop_count_start = cur_time(); 
count_2hop_detours(knn_graph_view, detour_count.view()); @@ -2022,7 +2030,6 @@ void optimize(raft::resources const& res, } else { auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); auto rev_graph_count = raft::make_host_vector(graph_size); - auto mst_graph = raft::make_host_matrix(0, 0); raft::copy(res, rev_graph.view(), d_rev_graph.view()); raft::copy(res, rev_graph_count.view(), d_rev_graph_count.view()); @@ -2036,6 +2043,8 @@ void optimize(raft::resources const& res, guarantee_connectivity); } + raft::resource::sync_stream(res); + if (is_ptr_host_accessible(new_graph.data_handle())) { // following checks require host access log_replaced_edges_stats(new_graph.data_handle(), graph_size, output_graph_degree); From 14e9f3ebc94aec7031d5c8eb685dc9b6fb36595d Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Tue, 3 Mar 2026 12:41:33 +0000 Subject: [PATCH 10/22] bugfix --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index f77f6367e5..5b2893e77f 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -1749,8 +1749,8 @@ void prune_graph_gpu(raft::resources const& res, knn_graph_degree, output_graph_degree, d_detour_count.data_handle(), - output_device_accessible ? d_output_graph.data_handle() - : output_graph_ptr + i_batch * batch_size * output_graph_degree, + output_device_accessible ? 
output_graph_ptr + i_batch * batch_size * output_graph_degree + : d_output_graph.data_handle(), batch_size, i_batch, d_invalid_neighbor_list.data_handle()); From 416558d40b1207ca7f1b8aad0aeda68b24e68aea Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Thu, 5 Mar 2026 21:43:49 +0000 Subject: [PATCH 11/22] fuse and simplify pruning, remove CPU path --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 331 +++++------------- 1 file changed, 92 insertions(+), 239 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 5b2893e77f..25be6ae393 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -157,79 +157,6 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, } } -template -__global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph_degree] - const uint32_t graph_size, - const uint32_t graph_degree, - const uint32_t degree, - const uint32_t batch_size, - const uint32_t batch_id, - uint8_t* const detour_count, // [batch_size, graph_degree] - uint32_t* const num_no_detour_edges, // [batch_size] - uint64_t* const stats) -{ - __shared__ uint32_t smem_num_detour[MAX_DEGREE]; - extern __shared__ unsigned char smem_buf[]; - IdxT* const smem_knn_iA_neighbors = reinterpret_cast(smem_buf); - - uint64_t* const num_retain = stats; - uint64_t* const num_full = stats + 1; - - const uint64_t iA = blockIdx.x + (batch_size * batch_id); - const uint64_t iA_batch = blockIdx.x; - - if (iA >= graph_size) { return; } - - // Load this node's neighbor row into shared memory to reduce global reads - for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { - smem_num_detour[k] = 0; - smem_knn_iA_neighbors[k] = knn_graph[k + ((uint64_t)graph_degree * iA)]; - if (smem_knn_iA_neighbors[k] == iA) { - // Lower the priority of self-edge - smem_num_detour[k] = graph_degree; - } - } - __syncthreads(); - - // 
count number of detours (A->D->B) - for (uint32_t kAD = 0; kAD < graph_degree - 1; kAD++) { - const uint64_t iD = smem_knn_iA_neighbors[kAD]; - if (iD >= graph_size) { continue; } - for (uint32_t kDB = threadIdx.x; kDB < graph_degree; kDB += blockDim.x) { - const uint64_t iB_candidate = knn_graph[kDB + ((uint64_t)graph_degree * iD)]; - for (uint32_t kAB = kAD + 1; kAB < graph_degree; kAB++) { - // if ( kDB < kAB ) - { - const uint64_t iB = smem_knn_iA_neighbors[kAB]; - if (iB == iB_candidate) { - atomicAdd(smem_num_detour + kAB, 1); - break; - } - } - } - } - __syncthreads(); - } - - uint32_t num_edges_no_detour = 0; - for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { - detour_count[k + (graph_degree * iA_batch)] = min(smem_num_detour[k], (uint32_t)255); - if (smem_num_detour[k] == 0) { num_edges_no_detour++; } - } - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 1); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 2); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 4); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 8); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 16); - num_edges_no_detour = min(num_edges_no_detour, degree); - - if (threadIdx.x == 0) { - num_no_detour_edges[iA_batch] = num_edges_no_detour; - atomicAdd((unsigned long long int*)num_retain, (unsigned long long int)num_edges_no_detour); - if (num_edges_no_detour >= degree) { atomicAdd((unsigned long long int*)num_full, 1); } - } -} - template __global__ void kern_make_rev_graph(const IdxT* const dest_nodes, // [graph_size] IdxT* const rev_graph, // [size, degree] @@ -269,48 +196,98 @@ __global__ void kern_make_rev_graph_k(const IdxT* const dest_nodes, // [grap } } -// Based on the detour count, select the smallest detour count and its index -// (Pruning Update Kernel) -template -__global__ void kern_select_smallest_detour_neighbors( - const IdxT* const 
knn_graph, // [graph_chunk_size, graph_degree] - uint64_t graph_size, - uint64_t knn_graph_degree, - uint64_t output_graph_degree, - uint8_t* const d_detour_count, // [batch_size, graph_degree] - IdxT* output_graph_ptr, - const uint32_t batch_size, // [batch_size, output_graph_degree] - const uint32_t batch_id, - uint32_t* const d_invalid_neighbor_list) +template +__global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph_degree] + IdxT* const output_graph_ptr, + const uint32_t graph_size, + const uint32_t knn_graph_degree, + const uint32_t output_graph_degree, + const uint32_t batch_size, + const uint32_t batch_id, + uint32_t* const d_invalid_neighbor_list, + uint64_t* const stats) { - assert(blockDim.x == 32); + extern __shared__ unsigned char smem_buf[]; - // Allocate shared memory for detour counts and their indices - extern __shared__ IdxT smem_indices[]; - uint16_t* smem_detour_count = (uint16_t*)&smem_indices[knn_graph_degree]; + const uint32_t wid = threadIdx.x / raft::WarpSize; + const uint32_t lane_id = threadIdx.x % raft::WarpSize; - const uint64_t nid = blockIdx.x + (batch_size * batch_id); - const uint64_t nid_batch = blockIdx.x; + IdxT* const smem_indices = + reinterpret_cast(smem_buf + wid * knn_graph_degree * sizeof(IdxT)); + uint32_t* const smem_num_detour = reinterpret_cast( + smem_buf + wid * knn_graph_degree * sizeof(IdxT) + num_warps * knn_graph_degree * sizeof(IdxT)); + + uint64_t* const num_retain = stats; + uint64_t* const num_full = stats + 1; + + const unsigned warp_mask = 0xffffffff; const uint32_t maxval16 = 0x0000ffff; + const uint64_t nid_batch = blockIdx.x * num_warps + wid; + const uint64_t nid = nid_batch + (batch_size * batch_id); + if (nid >= graph_size) { return; } - // Load indices and detour counts for each neighbor; invalidate out-of-bounds entries - for (uint32_t k = threadIdx.x; k < knn_graph_degree; k += blockDim.x) { - smem_indices[k] = knn_graph[knn_graph_degree * nid + k]; - 
smem_detour_count[k] = (smem_indices[k] >= graph_size) - ? maxval16 - : (uint16_t)d_detour_count[nid_batch * knn_graph_degree + k]; + // Load this node's neighbor row into shared memory to reduce global reads + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { + smem_num_detour[k] = 0; + smem_indices[k] = knn_graph[k + ((uint64_t)knn_graph_degree * nid)]; + if (smem_indices[k] == nid) { + // Lower the priority of self-edge + smem_num_detour[k] = knn_graph_degree; + } } __syncwarp(); - const unsigned warp_mask = 0xffffffff; + // count number of detours (A->D->B) + for (uint32_t kAD = 0; kAD < knn_graph_degree - 1; kAD++) { + const uint64_t iD = smem_indices[kAD]; + if (iD >= graph_size) { continue; } + for (uint32_t kDB = lane_id; kDB < knn_graph_degree; kDB += raft::WarpSize) { + const uint64_t iB_candidate = knn_graph[kDB + ((uint64_t)knn_graph_degree * iD)]; + for (uint32_t kAB = kAD + 1; kAB < knn_graph_degree; kAB++) { + // if ( kDB < kAB ) + { + const uint64_t iB = smem_indices[kAB]; + if (iB == iB_candidate) { + atomicAdd(smem_num_detour + kAB, 1); + break; + } + } + } + } + __syncwarp(); + } + + uint32_t num_edges_no_detour = 0; + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { + smem_num_detour[k] = min(smem_num_detour[k], maxval16); + if (smem_num_detour[k] == 0) { num_edges_no_detour++; } + if (smem_indices[k] >= graph_size) { smem_num_detour[k] = maxval16; } + } + + __syncwarp(); + + num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 1); + num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 2); + num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 4); + num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 8); + num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 16); + num_edges_no_detour = min(num_edges_no_detour, output_graph_degree); + + if (lane_id == 0) { + atomicAdd((unsigned long long int*)num_retain, 
(unsigned long long int)num_edges_no_detour); + if (num_edges_no_detour >= output_graph_degree) { + atomicAdd((unsigned long long int*)num_full, 1); + } + } + for (uint32_t i = 0; i < output_graph_degree; i++) { uint32_t local_min = maxval16; uint32_t local_idx = maxval16; - for (uint32_t k = threadIdx.x; k < knn_graph_degree; k += blockDim.x) { - if (smem_detour_count[k] < local_min) { - local_min = smem_detour_count[k]; + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { + if (smem_num_detour[k] < local_min) { + local_min = smem_num_detour[k]; local_idx = k; } } @@ -321,18 +298,18 @@ __global__ void kern_select_smallest_detour_neighbors( uint32_t warp_local_idx = warp_min_with_tag & 0xffff; if (warp_min_count == maxval16 || warp_local_idx == maxval16) { - if (threadIdx.x == 0) { atomicExch(d_invalid_neighbor_list, 1u); } + if (lane_id == 0) { atomicExch(d_invalid_neighbor_list, 1u); } break; } IdxT selected_node = smem_indices[warp_local_idx]; - for (uint32_t k = threadIdx.x; k < knn_graph_degree; k += blockDim.x) { - if (smem_indices[k] == selected_node) { smem_detour_count[k] = maxval16; } + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { + if (smem_indices[k] == selected_node) { smem_num_detour[k] = maxval16; } } __syncwarp(warp_mask); - if (threadIdx.x == 0) { output_graph_ptr[nid_batch * output_graph_degree + i] = selected_node; } + if (lane_id == 0) { output_graph_ptr[nid_batch * output_graph_degree + i] = selected_node; } } } @@ -1690,15 +1667,6 @@ void prune_graph_gpu(raft::resources const& res, RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); - constexpr int MAX_DEGREE = 1024; - if (knn_graph_degree > MAX_DEGREE) { - RAFT_FAIL( - "The degree of input knn graph is too large (%zu). 
" - "It must be equal to or smaller than %d.", - knn_graph_degree, - MAX_DEGREE); - } - const double prune_start = cur_time(); uint64_t num_keep __attribute__((unused)) = 0; @@ -1710,10 +1678,6 @@ void prune_graph_gpu(raft::resources const& res, device_matrix_view_from_host d_input_graph( res, raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree)); - auto d_detour_count = raft::make_device_mdarray( - res, default_ws_mr, raft::make_extents(batch_size, knn_graph_degree)); - auto d_num_no_detour_edges = raft::make_device_mdarray( - res, default_ws_mr, raft::make_extents(batch_size)); auto d_output_graph = raft::make_device_mdarray( res, default_ws_mr, raft::make_extents(batch_size, output_graph_degree)); auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); @@ -1721,40 +1685,23 @@ void prune_graph_gpu(raft::resources const& res, bool output_device_accessible = is_ptr_device_accessible(output_graph_ptr); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - raft::matrix::fill(res, d_detour_count.view(), uint8_t(0xff)); - raft::matrix::fill(res, d_num_no_detour_edges.view(), uint32_t(0)); - - const dim3 threads_prune(32, 1, 1); + const uint32_t num_warps = 4; + const dim3 threads_prune(raft::WarpSize * num_warps, 1, 1); const dim3 blocks_prune(batch_size, 1, 1); - const size_t prune_smem_size = knn_graph_degree * sizeof(IdxT); - kern_prune + const size_t prune_smem_size = num_warps * knn_graph_degree * (sizeof(IdxT) + sizeof(uint32_t)); + kern_fused_prune <<>>( d_input_graph.data_handle(), + output_device_accessible ? 
output_graph_ptr + i_batch * batch_size * output_graph_degree + : d_output_graph.data_handle(), graph_size, knn_graph_degree, output_graph_degree, batch_size, i_batch, - d_detour_count.data_handle(), - d_num_no_detour_edges.data_handle(), + d_invalid_neighbor_list.data_handle(), dev_stats.data_handle()); - const size_t select_smem_size = (knn_graph_degree) * (sizeof(uint16_t) + sizeof(IdxT)); - const dim3 threads_select(32, 1, 1); - const dim3 blocks_select(batch_size, 1, 1); - kern_select_smallest_detour_neighbors - <<>>( - d_input_graph.data_handle(), - graph_size, - knn_graph_degree, - output_graph_degree, - d_detour_count.data_handle(), - output_device_accessible ? output_graph_ptr + i_batch * batch_size * output_graph_degree - : d_output_graph.data_handle(), - batch_size, - i_batch, - d_invalid_neighbor_list.data_handle()); - if (!output_device_accessible) { size_t copy_size = std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * @@ -1800,79 +1747,6 @@ void prune_graph_gpu(raft::resources const& res, (double)num_full / graph_size * 100); } -template -void prune_graph_cpu(IdxT* knn_graph_ptr, - uint64_t graph_size, - uint64_t knn_graph_degree, - IdxT* output_graph_ptr, - uint64_t output_graph_degree) -{ - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune"); - auto detour_count = raft::make_host_matrix(graph_size, knn_graph_degree); - - auto knn_graph_view = - raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree); - - { - const double time_2hop_count_start = cur_time(); - - count_2hop_detours(knn_graph_view, detour_count.view()); - - const double time_2hop_count_end = cur_time(); - RAFT_LOG_DEBUG("# Time for 2-hop detour counting on CPU: %.1lf sec", - time_2hop_count_end - time_2hop_count_start); - } - bool invalid_neighbor_list = false; -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - uint64_t pk = 0; - uint32_t num_detour = 0; - for (uint32_t l = 0; l < knn_graph_degree && pk < 
output_graph_degree; l++) { - uint32_t next_num_detour = std::numeric_limits::max(); - for (uint64_t k = 0; k < knn_graph_degree; k++) { - const auto num_detour_k = detour_count(i, k); - if (num_detour_k > num_detour) { - next_num_detour = std::min(static_cast(num_detour_k), next_num_detour); - } - - if (num_detour_k != num_detour) { continue; } - - const auto candidate_node = knn_graph_view(i, k); - bool dup = false; - for (uint32_t dk = 0; dk < pk; dk++) { - if (candidate_node == output_graph_ptr[i * output_graph_degree + dk]) { - dup = true; - break; - } - } - if (!dup && candidate_node < graph_size) { - output_graph_ptr[i * output_graph_degree + pk] = candidate_node; - pk += 1; - } - if (pk >= output_graph_degree) break; - } - if (pk >= output_graph_degree) break; - - if (next_num_detour == std::numeric_limits::max()) { break; } - num_detour = next_num_detour; - } - if (pk != output_graph_degree) { - RAFT_LOG_DEBUG( - "Couldn't find the output_graph_degree (%lu) smallest detourable count nodes for " - "node %lu in the rank-based node reranking process", - output_graph_degree, - i); - invalid_neighbor_list = true; - } - } - RAFT_EXPECTS( - !invalid_neighbor_list, - "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " - "many invalid or duplicated neighbor nodes. 
This error can occur, for example, if too many " - "overflows occur during the norm computation between the dataset vectors."); -} - // TODO allow pinned input for both knn_graph and new_graph template void optimize(raft::resources const& res, @@ -1939,22 +1813,8 @@ void optimize(raft::resources const& res, } } - // prune graph -- will use GPU path if possible, otherwise CPU path - // we only need to check in case input is not alreadydevice accessible - bool use_gpu_prune = use_gpu; - if (!inout_device_accessible) { - try { - auto d_input_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); - } catch (std::bad_alloc& e) { - RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU"); - use_gpu_prune = false; - } catch (raft::logic_error& e) { - RAFT_LOG_DEBUG("Insufficient memory for pruning on GPU (logic error)"); - use_gpu_prune = false; - } - } - if (use_gpu_prune) { + // prune graph -- will always use GPU path + { // should be noop in case input is already device accessible device_matrix_view_from_host d_input_graph( res, @@ -1967,13 +1827,6 @@ void optimize(raft::resources const& res, knn_graph_degree, new_graph.data_handle(), output_graph_degree); - - } else { - prune_graph_cpu(knn_graph.data_handle(), - graph_size, - knn_graph_degree, - new_graph.data_handle(), - output_graph_degree); } // reverse graph creation will always use the GPU From d8d8bd877db9596720efaf67bb1373084dbf17c8 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Thu, 5 Mar 2026 22:49:16 +0000 Subject: [PATCH 12/22] cleanup merge, remove CPU path --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 339 +++++------------- 1 file changed, 85 insertions(+), 254 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 25be6ae393..392edc97d9 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -197,8 +197,8 @@ 
__global__ void kern_make_rev_graph_k(const IdxT* const dest_nodes, // [grap } template -__global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph_degree] - IdxT* const output_graph_ptr, +__global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph_degree] + IdxT* const output_graph_ptr, // [batch_size, output_graph_degree] const uint32_t graph_size, const uint32_t knn_graph_degree, const uint32_t output_graph_degree, @@ -337,8 +337,8 @@ __device__ void thread_shift_array(T* array, uint64_t num) } } -template -__global__ void kern_merge_graph(IdxT* output_graph, +template +__global__ void kern_merge_graph(IdxT* output_graph, // [batch_size, output_graph_degree] const IdxT* const rev_graph, uint32_t* const rev_graph_count, // [graph_size] const uint32_t graph_size, @@ -352,29 +352,32 @@ __global__ void kern_merge_graph(IdxT* output_graph, bool* check_num_protected_edges) { extern __shared__ unsigned char smem_buf[]; - IdxT* smem_sorted_output_graph = reinterpret_cast(smem_buf); - assert(blockDim.x == 32); + const uint32_t wid = threadIdx.x / raft::WarpSize; + const uint32_t lane_id = threadIdx.x % raft::WarpSize; - const uint64_t nid = blockIdx.x + (batch_size * batch_id); - if (nid >= graph_size) { return; } + IdxT* smem_sorted_output_graph = + reinterpret_cast(smem_buf + wid * output_graph_degree * sizeof(IdxT)); + + const uint64_t nid_batch = blockIdx.x * num_warps + wid; + const uint64_t nid = nid_batch + (batch_size * batch_id); - if (threadIdx.x == 0) check_num_protected_edges[0] = true; + if (nid >= graph_size) { return; } - const auto mst_graph_num_edges = mst_graph_num_edges_ptr[nid]; + const auto mst_graph_num_edges = guarantee_connectivity ? mst_graph_num_edges_ptr[nid] : 0; // If guarantee_connectivity == true, use a temporal list to merge the // neighbor lists of the graphs. 
if (guarantee_connectivity) { - for (uint32_t i = threadIdx.x; i < mst_graph_degree; i += 32) { + for (uint32_t i = lane_id; i < mst_graph_degree; i += raft::WarpSize) { smem_sorted_output_graph[i] = mst_graph[nid * mst_graph_degree + i]; } __syncwarp(); for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; (pruned_j < output_graph_degree) && (output_j < output_graph_degree); pruned_j++) { - const auto v = output_graph[output_graph_degree * nid + pruned_j]; + const auto v = output_graph[output_graph_degree * nid_batch + pruned_j]; unsigned int dup = 0; - for (uint32_t m = threadIdx.x; m < output_j; m += 32) { + for (uint32_t m = lane_id; m < output_j; m += raft::WarpSize) { if (v == smem_sorted_output_graph[m]) { dup = 1; break; @@ -383,7 +386,7 @@ __global__ void kern_merge_graph(IdxT* output_graph, unsigned int warp_dup = __ballot_sync(0xffffffff, dup); if (warp_dup == 0) { - if (threadIdx.x == 0) smem_sorted_output_graph[output_j] = v; + if (lane_id == 0) smem_sorted_output_graph[output_j] = v; output_j++; } __syncwarp(); @@ -391,8 +394,8 @@ __global__ void kern_merge_graph(IdxT* output_graph, } else { - for (uint32_t i = threadIdx.x; i < output_graph_degree; i += 32) { - smem_sorted_output_graph[i] = output_graph[output_graph_degree * nid + i]; + for (uint32_t i = lane_id; i < output_graph_degree; i += raft::WarpSize) { + smem_sorted_output_graph[i] = output_graph[output_graph_degree * nid_batch + i]; } __syncwarp(); } @@ -412,7 +415,7 @@ __global__ void kern_merge_graph(IdxT* output_graph, if (pos < num_protected_edges) { continue; } uint64_t num_shift = pos - num_protected_edges; if (pos >= output_graph_degree) { num_shift = output_graph_degree - num_protected_edges - 1; } - if (threadIdx.x == 0) { + if (lane_id == 0) { thread_shift_array(smem_sorted_output_graph + num_protected_edges, num_shift); smem_sorted_output_graph[num_protected_edges] = rev_graph[kr + (output_graph_degree * nid)]; } @@ -420,8 +423,8 @@ __global__ void kern_merge_graph(IdxT* 
output_graph, } } - for (uint32_t i = threadIdx.x; i < output_graph_degree; i += 32) { - output_graph[(output_graph_degree * nid) + i] = smem_sorted_output_graph[i]; + for (uint32_t i = lane_id; i < output_graph_degree; i += raft::WarpSize) { + output_graph[(output_graph_degree * nid_batch) + i] = smem_sorted_output_graph[i]; } } @@ -780,8 +783,8 @@ void merge_graph_gpu(raft::resources const& res, IdxT* output_graph_ptr, const IdxT* d_rev_graph_ptr, uint32_t* d_rev_graph_count_ptr, - const IdxT* mst_graph_ptr, - const uint32_t* mst_graph_num_edges_ptr, + IdxT* mst_graph_ptr, + uint32_t* mst_graph_num_edges_ptr, uint64_t graph_size, uint64_t output_graph_degree, bool guarantee_connectivity) @@ -789,36 +792,62 @@ void merge_graph_gpu(raft::resources const& res, raft::common::nvtx::range block_scope( "cagra::graph::optimize/combine"); + auto default_ws_mr = raft::resource::get_workspace_resource(res); const double merge_graph_start = cur_time(); - device_matrix_view_from_host d_output_graph( - res, - raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree)); - auto d_check_num_protected_edges = raft::make_device_scalar(res, true); + auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); uint32_t batch_size = std::min(static_cast(graph_size), static_cast(256 * 1024)); const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - const dim3 threads_merge(32, 1, 1); - const dim3 blocks_merge(batch_size, 1, 1); - const size_t merge_smem_size = (output_graph_degree + output_graph_degree) * sizeof(IdxT); + bool output_device_accessible = is_ptr_device_accessible(output_graph_ptr); + auto d_output_graph = raft::make_device_mdarray( + res, + default_ws_mr, + raft::make_extents(output_device_accessible ? 0 : batch_size, output_graph_degree)); + + device_matrix_view_from_host d_mst_graph( + res, + raft::make_host_matrix_view( + mst_graph_ptr, guarantee_connectivity ? 
graph_size : 0, output_graph_degree)); + + device_matrix_view_from_host d_mst_graph_num_edges( + res, + raft::make_host_matrix_view( + mst_graph_num_edges_ptr, guarantee_connectivity ? graph_size : 0, 1)); + + const uint32_t num_warps = 4; + const dim3 threads_merge(raft::WarpSize * num_warps, 1, 1); + const dim3 blocks_merge(batch_size / num_warps, 1, 1); + const size_t merge_smem_size = num_warps * output_graph_degree * sizeof(IdxT); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - kern_merge_graph + kern_merge_graph <<>>( - d_output_graph.data_handle(), + output_device_accessible ? output_graph_ptr + (i_batch * batch_size * output_graph_degree) + : d_output_graph.data_handle(), d_rev_graph_ptr, d_rev_graph_count_ptr, static_cast(graph_size), static_cast(output_graph_degree), - mst_graph_ptr, + d_mst_graph.data_handle(), static_cast(output_graph_degree), - mst_graph_num_edges_ptr, + d_mst_graph_num_edges.data_handle(), batch_size, i_batch, guarantee_connectivity, d_check_num_protected_edges.data_handle()); + + if (!output_device_accessible) { + size_t copy_size = + std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * + output_graph_degree; + raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, + d_output_graph.data_handle(), + copy_size, + raft::resource::get_cuda_stream(res)); + } } bool check_num_protected_edges = true; @@ -827,13 +856,6 @@ void merge_graph_gpu(raft::resources const& res, 1, raft::resource::get_cuda_stream(res)); - if (d_output_graph.allocated_memory()) { - raft::copy( - res, - raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), - d_output_graph.view()); - } - const auto merge_graph_end = cur_time(); RAFT_EXPECTS(check_num_protected_edges, "Failed to merge the MST, pruned, and reverse edge graphs. 
" @@ -844,92 +866,6 @@ void merge_graph_gpu(raft::resources const& res, (merge_graph_end - merge_graph_start) * 1000.0); } -template -void merge_graph_cpu(IdxT* output_graph_ptr, - const IdxT* rev_graph_ptr, - const uint32_t* rev_graph_count_ptr, - const IdxT* mst_graph_ptr, - const uint32_t* mst_graph_num_edges_ptr, - uint64_t graph_size, - uint64_t output_graph_degree, - bool guarantee_connectivity) -{ - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/combine"); - - const double time_replace_start = cur_time(); - - bool check_num_protected_edges = true; -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - auto my_rev_graph = rev_graph_ptr + (output_graph_degree * i); - auto my_out_graph = output_graph_ptr + (output_graph_degree * i); - - std::vector temp_output_neighbor_list; - if (guarantee_connectivity) { - temp_output_neighbor_list.resize(output_graph_degree); - my_out_graph = temp_output_neighbor_list.data(); - const auto mst_graph_num_edges = mst_graph_num_edges_ptr[i]; - - for (uint32_t j = 0; j < mst_graph_num_edges; j++) { - my_out_graph[j] = mst_graph_ptr[i * output_graph_degree + j]; - } - - for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; - (pruned_j < output_graph_degree) && (output_j < output_graph_degree); - pruned_j++) { - const auto v = output_graph_ptr[output_graph_degree * i + pruned_j]; - - bool dup = false; - for (uint32_t m = 0; m < output_j; m++) { - if (v == my_out_graph[m]) { - dup = true; - break; - } - } - - if (!dup) { - my_out_graph[output_j] = v; - output_j++; - } - } - } - - const auto num_protected_edges = - std::max(mst_graph_num_edges_ptr[i], output_graph_degree / 2); - if (num_protected_edges > output_graph_degree) { check_num_protected_edges = false; } - if (num_protected_edges == output_graph_degree) continue; - - auto kr = std::min(rev_graph_count_ptr[i], output_graph_degree); - while (kr) { - kr -= 1; - if (my_rev_graph[kr] < graph_size) { - uint64_t pos = 
pos_in_array(my_rev_graph[kr], my_out_graph, output_graph_degree); - if (pos < num_protected_edges) { continue; } - uint64_t num_shift = pos - num_protected_edges; - if (pos >= output_graph_degree) { - num_shift = output_graph_degree - num_protected_edges - 1; - } - shift_array(my_out_graph + num_protected_edges, num_shift); - my_out_graph[num_protected_edges] = my_rev_graph[kr]; - } - } - - if (guarantee_connectivity) { - for (uint32_t j = 0; j < output_graph_degree; j++) { - output_graph_ptr[(output_graph_degree * i) + j] = my_out_graph[j]; - } - } - } - RAFT_EXPECTS(check_num_protected_edges, - "Failed to merge the MST, pruned, and reverse edge graphs. Some nodes have too " - "many MST optimization edges."); - - const double time_replace_end = cur_time(); - RAFT_LOG_DEBUG("# Replacing edges time: %.1lf ms", - (time_replace_end - time_replace_start) * 1000.0); -} - template void make_reverse_graph_gpu(raft::resources const& res, IdxT* d_rev_graph_ptr, @@ -1585,58 +1521,6 @@ void mst_optimization(raft::resources const& res, RAFT_LOG_DEBUG("# MST optimization time: %.1lf sec", time_mst_opt_end - time_mst_opt_start); } -template -void count_2hop_detours(raft::host_matrix_view knn_graph, - raft::host_matrix_view detour_count) -{ - RAFT_EXPECTS(knn_graph.extent(0) == detour_count.extent(0), - "knn_graph and detour_count are expected to have the same number of rows"); - RAFT_EXPECTS(knn_graph.extent(1) == detour_count.extent(1), - "knn_graph and detour_count are expected to have the same number of cols"); - const uint64_t graph_size = knn_graph.extent(0); - const uint64_t graph_degree = knn_graph.extent(1); - -#pragma omp parallel for - for (IdxT iA = 0; iA < graph_size; iA++) { - // Create a list of nodes, iB_candidates, that can be reached in 2-hops from node A. 
- auto iB_candidates = - raft::make_host_vector((graph_degree - 1) * (graph_degree - 1)); - for (uint64_t kAC = 0; kAC < graph_degree - 1; kAC++) { - IdxT iC = knn_graph(iA, kAC); - for (uint64_t kCB = 0; kCB < graph_degree - 1; kCB++) { - IdxT iB_candidate; - if (iC == iA || iC >= graph_size) { - iB_candidate = graph_size; - } else { - iB_candidate = knn_graph(iC, kCB); - if (iB_candidate == iA || iB_candidate == iC) { iB_candidate = graph_size; } - } - uint64_t idx; - if (kAC < kCB) { - idx = (kCB * kCB) + kAC; - } else { - idx = (kAC * (kAC + 1)) + kCB; - } - iB_candidates(idx) = iB_candidate; - } - } - // Count how many 2-hop detours are on each edge of node A. - for (uint64_t kAB = 0; kAB < graph_degree; kAB++) { - constexpr uint32_t max_count = 255; - uint32_t count = 0; - IdxT iB = knn_graph(iA, kAB); - if (iB == iA) { - count = max_count; - } else { - for (uint64_t idx = 0; idx < kAB * kAB; idx++) { - if (iB_candidates(idx) == iB) { count += 1; } - } - } - detour_count(iA, kAB) = std::min(count, max_count); - } - } -} - // // Prune unimportant edges based on 2-hop detour counts. // @@ -1678,16 +1562,18 @@ void prune_graph_gpu(raft::resources const& res, device_matrix_view_from_host d_input_graph( res, raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree)); - auto d_output_graph = raft::make_device_mdarray( - res, default_ws_mr, raft::make_extents(batch_size, output_graph_degree)); auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); bool output_device_accessible = is_ptr_device_accessible(output_graph_ptr); + auto d_output_graph = raft::make_device_mdarray( + res, + default_ws_mr, + raft::make_extents(output_device_accessible ? 
0 : batch_size, output_graph_degree)); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { const uint32_t num_warps = 4; const dim3 threads_prune(raft::WarpSize * num_warps, 1, 1); - const dim3 blocks_prune(batch_size, 1, 1); + const dim3 blocks_prune(batch_size / num_warps, 1, 1); const size_t prune_smem_size = num_warps * knn_graph_degree * (sizeof(IdxT) + sizeof(uint32_t)); kern_fused_prune <<>>( @@ -1775,54 +1661,36 @@ void optimize(raft::resources const& res, raft::common::nvtx::range fun_scope( "cagra::graph::optimize(%zu, %zu, %u)", graph_size, knn_graph_degree, output_graph_degree); - // check if input and output are both device accessible - // in this case we assume data to be ONLY device accessible and not host accessible - // furthermore we ensure all large allocations to go to the large workspace resource - // and all small allocations to go to the default workspace resource - bool inout_device_accessible = false; - { - bool input_device_accessible = is_ptr_device_accessible(knn_graph.data_handle()); - bool output_device_accessible = is_ptr_device_accessible(new_graph.data_handle()); - RAFT_EXPECTS(input_device_accessible == output_device_accessible, - "Input and output must be either both device accessible or both host accessible"); - inout_device_accessible = input_device_accessible && output_device_accessible; - } - // MST optimization // currently, only using GPU path for MST optimization - auto p_mst_graph = raft::make_pinned_matrix(res, 0, 0); - auto p_mst_graph_num_edges = raft::make_pinned_vector(res, graph_size); - auto p_mst_graph_num_edges_ptr = p_mst_graph_num_edges.data_handle(); -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - p_mst_graph_num_edges_ptr[i] = 0; - } + auto mst_graph = raft::make_host_matrix(0, 0); + auto mst_graph_num_edges = raft::make_host_vector(0); + if (guarantee_connectivity) { + auto mst_graph_num_edges = raft::make_host_vector(graph_size); + auto mst_graph_num_edges_ptr = 
mst_graph_num_edges.data_handle(); +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + mst_graph_num_edges_ptr[i] = 0; + } raft::common::nvtx::range block_scope( "cagra::graph::optimize/check_connectivity"); - p_mst_graph = raft::make_pinned_matrix( - res, graph_size, output_graph_degree); + mst_graph = + raft::make_host_matrix(graph_size, output_graph_degree); RAFT_LOG_INFO("MST optimization is used to guarantee graph connectivity."); - mst_optimization( - res, knn_graph, p_mst_graph.view(), p_mst_graph_num_edges.view(), use_gpu); + mst_optimization(res, knn_graph, mst_graph.view(), mst_graph_num_edges.view(), use_gpu); for (uint64_t i = 0; i < graph_size; i++) { if (i < 8 || i >= graph_size - 8) { - RAFT_LOG_DEBUG("# p_mst_graph_num_edges_ptr[%lu]: %u\n", i, p_mst_graph_num_edges_ptr[i]); + RAFT_LOG_DEBUG("# mst_graph_num_edges_ptr[%lu]: %u\n", i, mst_graph_num_edges_ptr[i]); } } } // prune graph -- will always use GPU path { - // should be noop in case input is already device accessible - device_matrix_view_from_host d_input_graph( - res, - raft::make_host_matrix_view( - knn_graph.data_handle(), graph_size, knn_graph_degree)); - prune_graph_gpu(res, - d_input_graph.data_handle(), + knn_graph.data_handle(), graph_size, knn_graph_degree, new_graph.data_handle(), @@ -1846,51 +1714,14 @@ void optimize(raft::resources const& res, RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", (time_make_end - time_make_start) * 1000.0); - // merge graph -- will use GPU path if possible, otherwise CPU path - // we only need to check in case output is not already device accessible - bool use_gpu_merge = use_gpu; - if (!inout_device_accessible) { - try { - auto d_new_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); - } catch (std::bad_alloc& e) { - RAFT_LOG_DEBUG("Insufficient memory for merging on GPU"); - use_gpu_merge = false; - } catch (raft::logic_error& e) { - RAFT_LOG_DEBUG("Insufficient 
memory for merging on GPU (logic error)"); - use_gpu_merge = false; - } - } - - if (use_gpu_merge) { - // should be noop in case output is already device accessible - device_matrix_view_from_host d_new_graph( - res, - raft::make_host_matrix_view( - new_graph.data_handle(), graph_size, output_graph_degree)); - + // merge graph -- will always use GPU path + { merge_graph_gpu(res, - d_new_graph.data_handle(), + new_graph.data_handle(), d_rev_graph.data_handle(), d_rev_graph_count.data_handle(), - p_mst_graph.data_handle(), - p_mst_graph_num_edges.data_handle(), - graph_size, - output_graph_degree, - guarantee_connectivity); - - if (d_new_graph.allocated_memory()) { raft::copy(res, new_graph, d_new_graph.view()); } - } else { - auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); - auto rev_graph_count = raft::make_host_vector(graph_size); - raft::copy(res, rev_graph.view(), d_rev_graph.view()); - raft::copy(res, rev_graph_count.view(), d_rev_graph_count.view()); - - merge_graph_cpu(new_graph.data_handle(), - rev_graph.data_handle(), - rev_graph_count.data_handle(), - p_mst_graph.data_handle(), - p_mst_graph_num_edges_ptr, + mst_graph.data_handle(), + mst_graph_num_edges.data_handle(), graph_size, output_graph_degree, guarantee_connectivity); From 00c42045aa9f0f7d148865ec7e570078e5f16658 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 6 Mar 2026 00:01:10 +0000 Subject: [PATCH 13/22] batch reverse creation --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 116 ++++++++---------- 1 file changed, 52 insertions(+), 64 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 392edc97d9..9f2bc09d86 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -157,38 +157,24 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, } } -template -__global__ void kern_make_rev_graph(const IdxT* const 
dest_nodes, // [graph_size] - IdxT* const rev_graph, // [size, degree] - uint32_t* const rev_graph_count, // [graph_size] - const uint32_t graph_size, - const uint32_t degree) -{ - const uint32_t tid = threadIdx.x + (blockDim.x * blockIdx.x); - const uint32_t tnum = blockDim.x * gridDim.x; - - for (uint32_t src_id = tid; src_id < graph_size; src_id += tnum) { - const IdxT dest_id = dest_nodes[src_id]; - if (dest_id >= graph_size) continue; - - const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); - if (pos < degree) { rev_graph[pos + ((uint64_t)degree * dest_id)] = src_id; } - } -} - template -__global__ void kern_make_rev_graph_k(const IdxT* const dest_nodes, // [graph_size] - IdxT* const rev_graph, // [size, degree] - uint32_t* const rev_graph_count, // [graph_size] - const uint32_t graph_size, - const uint32_t degree, - uint64_t k) +__global__ void kern_rev_graph_batched(const IdxT* const dest_nodes, // [batch_size, degree] + IdxT* const rev_graph, // [graph_size, degree] + uint32_t* const rev_graph_count, // [graph_size] + const uint32_t graph_size, + const uint32_t degree, + const uint32_t batch_size, + const uint32_t batch_id) { const uint64_t tid = threadIdx.x + (blockDim.x * blockIdx.x); const uint64_t tnum = blockDim.x * gridDim.x; - for (uint64_t src_id = tid; src_id < graph_size; src_id += tnum) { - IdxT dest_id = dest_nodes[k + (degree * src_id)]; + const uint64_t block_batch_size = min(batch_size, graph_size - batch_id * batch_size); + + for (uint64_t idx = tid; idx < block_batch_size * degree; idx += tnum) { + const IdxT dest_id = dest_nodes[idx]; + const uint32_t src_id = idx / degree; + if (dest_id >= graph_size) continue; const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); @@ -866,22 +852,18 @@ void merge_graph_gpu(raft::resources const& res, (merge_graph_end - merge_graph_start) * 1000.0); } -template +template void make_reverse_graph_gpu(raft::resources const& res, + IdxT* output_graph_ptr, IdxT* d_rev_graph_ptr, uint32_t* 
d_rev_graph_count_ptr, - InOutMatrixView new_graph) + uint64_t graph_size, + uint64_t output_graph_degree) { - const uint64_t graph_size = new_graph.extent(0); - const uint64_t output_graph_degree = new_graph.extent(1); - const IdxT* output_graph_ptr = new_graph.data_handle(); - raft::common::nvtx::range block_scope( "cagra::graph::optimize/reverse"); - auto dest_nodes = raft::make_host_vector(graph_size); - auto d_dest_nodes = raft::make_device_mdarray( - res, raft::resource::get_workspace_resource(res), raft::make_extents(graph_size)); + auto default_ws_mr = raft::resource::get_workspace_resource(res); raft::matrix::fill( res, @@ -893,36 +875,38 @@ void make_reverse_graph_gpu(raft::resources const& res, raft::make_device_vector_view(d_rev_graph_count_ptr, graph_size), uint32_t(0)); - bool output_graph_device_accessible = is_ptr_device_accessible(output_graph_ptr); - dim3 threads(256, 1, 1); - dim3 blocks(1024, 1, 1); + const uint32_t batch_size = + std::min(static_cast(graph_size), static_cast(256 * 1024)); + const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - for (uint64_t k = 0; k < output_graph_degree; k++) { - if (output_graph_device_accessible) { - kern_make_rev_graph_k<<>>( - output_graph_ptr, - d_rev_graph_ptr, - d_rev_graph_count_ptr, - static_cast(graph_size), - static_cast(output_graph_degree), - k); - } else { -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - dest_nodes(i) = output_graph_ptr[k + (output_graph_degree * i)]; - } - raft::resource::sync_stream(res); + bool output_device_accessible = is_ptr_device_accessible(output_graph_ptr); + auto d_output_graph = raft::make_device_mdarray( + res, + default_ws_mr, + raft::make_extents(output_device_accessible ? 
0 : batch_size, output_graph_degree)); - raft::copy(res, d_dest_nodes.view(), dest_nodes.view()); + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + dim3 threads(256, 1, 1); + dim3 blocks(1024, 1, 1); - kern_make_rev_graph<<>>( - d_dest_nodes.data_handle(), - d_rev_graph_ptr, - d_rev_graph_count_ptr, - static_cast(graph_size), - static_cast(output_graph_degree)); + if (!output_device_accessible) { + size_t copy_size = + std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * + output_graph_degree; + raft::copy(d_output_graph.data_handle(), + output_graph_ptr + i_batch * batch_size * output_graph_degree, + copy_size, + raft::resource::get_cuda_stream(res)); } - RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %lu \r", k, output_graph_degree); + kern_rev_graph_batched<<>>( + output_device_accessible ? output_graph_ptr + (i_batch * batch_size * output_graph_degree) + : d_output_graph.data_handle(), + d_rev_graph_ptr, + d_rev_graph_count_ptr, + static_cast(graph_size), + static_cast(output_graph_degree), + static_cast(batch_size), + static_cast(i_batch)); } raft::resource::sync_stream(res); @@ -1707,8 +1691,12 @@ void optimize(raft::resources const& res, const double time_make_start = cur_time(); - make_reverse_graph_gpu( - res, d_rev_graph.data_handle(), d_rev_graph_count.data_handle(), new_graph); + make_reverse_graph_gpu(res, + new_graph.data_handle(), + d_rev_graph.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, + output_graph_degree); const double time_make_end = cur_time(); RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", From 9e63a7c442d6725703cbb52e575b68e2625f0694 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 6 Mar 2026 12:05:48 +0000 Subject: [PATCH 14/22] add prefetch view to handle managed & host --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 17 +- cpp/src/neighbors/detail/cagra/utils.hpp | 300 +++++++++++++++++- 2 files changed, 313 insertions(+), 4 deletions(-) diff --git 
a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 9f2bc09d86..28006fa133 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -1543,8 +1543,19 @@ void prune_graph_gpu(raft::resources const& res, auto host_stats = raft::make_host_vector(2); raft::matrix::fill(res, dev_stats.view(), uint64_t(0)); - device_matrix_view_from_host d_input_graph( - res, raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree)); + // device_matrix_view_from_host d_input_graph( + // res, raft::make_host_matrix_view(knn_graph_ptr, graph_size, + // knn_graph_degree)); + + batched_device_view_from_host d_input_graph( + res, + raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree), + /*batch_size*/ graph_size, + /*read_only*/ true, + /*host_writeback*/ false, + /*initialize*/ true, + /*evict*/ true); + auto input_view = d_input_graph.next_view(); auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); @@ -1561,7 +1572,7 @@ void prune_graph_gpu(raft::resources const& res, const size_t prune_smem_size = num_warps * knn_graph_degree * (sizeof(IdxT) + sizeof(uint32_t)); kern_fused_prune <<>>( - d_input_graph.data_handle(), + input_view.data_handle(), output_device_accessible ? 
output_graph_ptr + i_batch * batch_size * output_graph_degree : d_output_graph.data_handle(), graph_size, diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index a59ac7fd57..c3d15e59f4 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -9,9 +9,13 @@ #include #include #include +#include +#include +#include +#include +#include #include #include - #include #include @@ -308,4 +312,298 @@ void copy_with_padding( raft::resource::get_cuda_stream(res))); } } + +/** + * Utility to create a batched device view from a host view + * + * This utility will create a batched device view from a host view and will handle the prefetch and + * writeback of the data Each batch can be referenced exactlyonce by calling the next_view() + * function + * + * @tparam T The type of the data + * @tparam IdxT The type of the index + * @param res The resources + * @param host_view The host view to create the batched device view from + * @param batch_size The batch size + * @param read_only Whether the data is read only (only for managed memory) + * @param host_writeback Whether to write back the data to the host (only for host memory) + * @param initialize Whether to initialize the data (only for managed memory) + * @param evict Whether to evict the data (only for managed memory) + * + * @return The batched device view + */ +template +class batched_device_view_from_host { + public: + batched_device_view_from_host(raft::resources const& res, + raft::host_matrix_view host_view, + uint64_t batch_size, + bool read_only = false, + bool host_writeback = false, + bool initialize = true, + bool evict = false) + : res_(res), + host_view_(host_view), + batch_size_(batch_size), + offset_(0), + batch_id_(0), + num_buffers_(2), + read_only_(read_only), + host_writeback_(host_writeback), + next_buffer_pos_(0), + evict_(evict), + initialize_(initialize) + { + cudaPointerAttributes attr; + 
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, host_view.data_handle())); + mem_type_ = attr.type; + // cudaMemoryTypeUnregistered = 0 + // cudaMemoryTypeHost = 1 + // cudaMemoryTypeDevice = 2 + // cudaMemoryTypeManaged = 3 + + prefetch_stream_ = raft::resource::get_cuda_stream(res); + writeback_stream_ = raft::resource::get_cuda_stream(res); + if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL)) { + if (raft::resource::get_stream_pool_size(res) >= 1) { + prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); + writeback_stream_ = raft::resource::get_stream_from_stream_pool(res); + } + } + + // allocations + if (mem_type_ == cudaMemoryTypeHost || mem_type_ == cudaMemoryTypeUnregistered) { + device_mem_[0].emplace(raft::make_device_mdarray( + res, + raft::resource::get_large_workspace_resource(res), + raft::make_extents(batch_size, host_view.extent(1)))); + device_ptr[0] = device_mem_[0]->data_handle(); + if (batch_size < static_cast(host_view.extent(0))) { + device_mem_[1].emplace(raft::make_device_mdarray( + res, + raft::resource::get_large_workspace_resource(res), + raft::make_extents(batch_size, host_view.extent(1)))); + device_ptr[1] = device_mem_[1]->data_handle(); + } + if (host_writeback_ && batch_size * 2 < static_cast(host_view.extent(0))) { + num_buffers_ = 3; + device_mem_[2].emplace(raft::make_device_mdarray( + res, + raft::resource::get_large_workspace_resource(res), + raft::make_extents(batch_size, host_view.extent(1)))); + device_ptr[2] = device_mem_[2]->data_handle(); + } + } + + // if data is managed and not for_write_ we can set the attribute on the device ptr + if (mem_type_ == cudaMemoryTypeManaged) { + // location_.type = CU_MEM_LOCATION_TYPE_DEVICE; + location_.type = cudaMemLocationTypeDevice; + location_.id = static_cast(raft::resource::get_device_id(res_)); + if (read_only_) { +#if CUDA_VERSION >= 13000 + RAFT_CUDA_TRY(cudaMemAdvise(host_view_.data_handle(), + host_view_.extent(0) * 
host_view_.extent(1) * sizeof(T), + cudaMemAdviseSetReadMostly, + location_)); +#else + RAFT_CUDA_TRY(cudaMemAdvise_v2(host_view_.data_handle(), + host_view_.extent(0) * host_view_.extent(1) * sizeof(T), + cudaMemAdviseSetReadMostly, + location_)); +#endif + // TODO maybe also reset upon destruction + } + } + + // prefetch next batch (0) + prefetch_next_batch(); + } + + bool prefetch_next_batch() + { + // this function will ensure the device_ptr [next_buffer_pos_] is pointing to the correct memory + // after the next synchronization with the prefetch stream + + // if data is on host and we are writing to it we will have to copy it back + // if data is on host we will have to copy it to the device_ptr + + // if data is managed and evict_ is true we can evict the data from device memory + // if data is managed we have to prefetch it + + bool next_batch_exists = offset_ < static_cast(host_view_.extent(0)); + + if (next_batch_exists) { + actual_batch_size_[next_buffer_pos_] = + next_batch_exists ? 
min(batch_size_, host_view_.extent(0) - offset_) : 0; + + switch (mem_type_) { + case cudaMemoryTypeManaged: +#if CUDA_VERSION >= 13000 + if (evict_ && batch_id_ > 1) { + // evict last active + CUdeviceptr dptrs[] = {device_ptr[next_buffer_pos_]}; + size_t sizes[] = {batch_size_ * host_view_.extent(1) * sizeof(T)}; + size_t prefetchLocIdxs[] = {0}; + RAFT_CUDA_TRY(cuMemDiscardBatchAsync( + dptrs, sizes, 1, &location_, prefetchLocIdxs, 1, 0, prefetch_stream_)); + } +#endif + // prefetch + device_ptr[next_buffer_pos_] = host_view_.data_handle() + offset_ * host_view_.extent(1); + if (initialize_) { + // managed API call to prefetch async +#if CUDA_VERSION >= 13000 + RAFT_CUDA_TRY(cudaMemPrefetchAsync( + device_ptr[next_buffer_pos_], + actual_batch_size_[next_buffer_pos_] * host_view_.extent(1) * sizeof(T), + location_, + 0, + prefetch_stream_)); +#else + RAFT_CUDA_TRY(cudaMemPrefetchAsync_v2( + device_ptr[next_buffer_pos_], + actual_batch_size_[next_buffer_pos_] * host_view_.extent(1) * sizeof(T), + location_, + 0, + prefetch_stream_)); +#endif + } else { + // managed API call to cuMemDiscardAndPrefetchBatchAsync (discard and prefetch batch) +#if CUDA_VERSION >= 13000 + CUdeviceptr dptrs[] = {device_ptr[next_buffer_pos_]}; + size_t sizes[] = {actual_batch_size_[next_buffer_pos_] * host_view_.extent(1) * + sizeof(T)}; + size_t prefetchLocIdxs[] = {0}; + RAFT_CUDA_TRY(cuMemDiscardAndPrefetchBatchAsync( + dptrs, sizes, 1, &location_, prefetchLocIdxs, 1, 0, prefetch_stream_)); +#endif + } + + break; + case cudaMemoryTypeHost: + case cudaMemoryTypeUnregistered: + if (host_writeback_ && batch_id_ > 1) { + writeback_stream_.synchronize(); + // copy back last active + uint32_t writeback_pos = (next_buffer_pos_ + num_buffers_ - 2) % num_buffers_; + uint64_t writeback_offset = (offset_ - 2 * batch_size_) * host_view_.extent(1); + raft::copy(host_view_.data_handle() + writeback_offset, + device_ptr[writeback_pos], + actual_batch_size_[writeback_pos] * host_view_.extent(1), + 
writeback_stream_); + } + if (initialize_) { + // prefetch next position + raft::copy(device_ptr[next_buffer_pos_], + host_view_.data_handle() + offset_ * host_view_.extent(1), + actual_batch_size_[next_buffer_pos_] * host_view_.extent(1), + prefetch_stream_); + } + + break; + case cudaMemoryTypeDevice: + // just move pointer to next position + device_ptr[next_buffer_pos_] = host_view_.data_handle() + offset_ * host_view_.extent(1); + break; + } + + offset_ += actual_batch_size_[next_buffer_pos_]; + // swap next_buffer_pos_ + next_buffer_pos_ = (next_buffer_pos_ + 1) % num_buffers_; + } + + return next_batch_exists; + } + + ~batched_device_view_from_host() noexcept + { + prefetch_stream_.synchronize(); + writeback_stream_.synchronize(); + raft::resource::sync_stream(res_); + + // if data is on host and for_write --> make sure to copy back last active + // if data is managed and evict --> evict last active + + // make sure to sync on prefetch & writeback stream & res + switch (mem_type_) { + case cudaMemoryTypeManaged: +#if CUDA_VERSION >= 13000 + if (evict_ && batch_id_ > 0) { + // managed API call to evict 2 + uint32_t evict_pos = (next_buffer_pos_ + num_buffers_ - 1) % num_buffers_; + CUdeviceptr dptrs[] = {device_ptr[evict_pos]}; + size_t sizes[] = {batch_size_ * host_view_.extent(1) * sizeof(T)}; + size_t prefetchLocIdxs[] = {0}; + RAFT_CUDA_TRY(cuMemDiscardBatchAsync( + dptrs, sizes, 1, &location_, prefetchLocIdxs, 1, 0, prefetch_stream_)); + } + prefetch_stream_.synchronize(); +#endif + break; + case cudaMemoryTypeHost: + case cudaMemoryTypeUnregistered: + if (host_writeback_ && batch_id_ > 0) { + // TODO managed API call to copy back last active + uint32_t writeback_pos = (next_buffer_pos_ + num_buffers_ - 1) % num_buffers_; + uint64_t writeback_offset = + (offset_ - actual_batch_size_[writeback_pos]) * host_view_.extent(1); + raft::copy(host_view_.data_handle() + writeback_offset, + device_ptr[writeback_pos], + actual_batch_size_[writeback_pos] * 
host_view_.extent(1), + writeback_stream_); + } + writeback_stream_.synchronize(); + break; + case cudaMemoryTypeDevice: break; + } + } + + /** + * Returns the next view of the batch + * + * This function will ensure the next batch is ready and will trigger the prefetch of the + * subsequent next batch + * + * @return The next view of the batch + */ + raft::device_matrix_view next_view() + { + RAFT_EXPECTS(batch_id_ * batch_size_ < host_view_.extent(0), "Batch index out of bounds"); + + // ensure current batch is ready + prefetch_stream_.synchronize(); + + // trigger prefetch of next batch + bool next_batch_exists = prefetch_next_batch(); + + batch_id_++; + + uint32_t current_pos = + (next_buffer_pos_ + num_buffers_ - (next_batch_exists ? 2 : 1)) % num_buffers_; + return raft::make_device_matrix_view( + device_ptr[current_pos], actual_batch_size_[current_pos], host_view_.extent(1)); + } + + private: + cudaMemoryType mem_type_; + const raft::resources& res_; + uint64_t batch_size_; + uint64_t offset_; + uint64_t num_buffers_; + bool initialize_; + rmm::cuda_stream_view prefetch_stream_; + rmm::cuda_stream_view writeback_stream_; + bool read_only_; + bool host_writeback_; + bool evict_; + int32_t next_buffer_pos_; + int32_t batch_id_; + cudaMemLocation location_; + std::optional> device_mem_[3]; + raft::host_matrix_view host_view_; + T* device_ptr[3]; + uint32_t actual_batch_size_[3]; +}; + } // namespace cuvs::neighbors::cagra::detail From a38ad525570d31882a1c86ff04eb679a6b1c4476 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 9 Mar 2026 20:49:08 +0000 Subject: [PATCH 15/22] fix batched iterator --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 123 +++---- cpp/src/neighbors/detail/cagra/utils.hpp | 313 ++++++++++-------- 2 files changed, 233 insertions(+), 203 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 28006fa133..ef8b1f8daf 100644 --- 
a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -22,6 +22,7 @@ #include #include +#include #include @@ -324,14 +325,14 @@ __device__ void thread_shift_array(T* array, uint64_t num) } template -__global__ void kern_merge_graph(IdxT* output_graph, // [batch_size, output_graph_degree] - const IdxT* const rev_graph, +__global__ void kern_merge_graph(IdxT* output_graph, // [batch_size, output_graph_degree] + const IdxT* const rev_graph, // [graph_size, output_graph_degree] uint32_t* const rev_graph_count, // [graph_size] const uint32_t graph_size, const uint32_t output_graph_degree, - const IdxT* const mst_graph, + const IdxT* const mst_graph, // [batch_size, output_graph_degree] const uint32_t mst_graph_degree, - const uint32_t* const mst_graph_num_edges_ptr, + const uint32_t* const mst_graph_num_edges_ptr, // [batch_size] const uint32_t batch_size, const uint32_t batch_id, bool guarantee_connectivity, @@ -350,12 +351,12 @@ __global__ void kern_merge_graph(IdxT* output_graph, // [batch_size, output_gra if (nid >= graph_size) { return; } - const auto mst_graph_num_edges = guarantee_connectivity ? mst_graph_num_edges_ptr[nid] : 0; + const auto mst_graph_num_edges = guarantee_connectivity ? mst_graph_num_edges_ptr[nid_batch] : 0; // If guarantee_connectivity == true, use a temporal list to merge the // neighbor lists of the graphs. 
if (guarantee_connectivity) { for (uint32_t i = lane_id; i < mst_graph_degree; i += raft::WarpSize) { - smem_sorted_output_graph[i] = mst_graph[nid * mst_graph_degree + i]; + smem_sorted_output_graph[i] = mst_graph[nid_batch * mst_graph_degree + i]; } __syncwarp(); for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; @@ -788,52 +789,54 @@ void merge_graph_gpu(raft::resources const& res, std::min(static_cast(graph_size), static_cast(256 * 1024)); const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - bool output_device_accessible = is_ptr_device_accessible(output_graph_ptr); - auto d_output_graph = raft::make_device_mdarray( + batched_device_view_from_host d_output_graph( res, - default_ws_mr, - raft::make_extents(output_device_accessible ? 0 : batch_size, output_graph_degree)); + raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), + /*batch_size*/ batch_size, + /*host_writeback*/ true, + /*initialize*/ true, + /*hmm_as_managed*/ false); - device_matrix_view_from_host d_mst_graph( + batched_device_view_from_host d_mst_graph( res, raft::make_host_matrix_view( - mst_graph_ptr, guarantee_connectivity ? graph_size : 0, output_graph_degree)); + mst_graph_ptr, guarantee_connectivity ? graph_size : 0, output_graph_degree), + /*batch_size*/ batch_size, + /*host_writeback*/ false, + /*initialize*/ true, + /*hmm_as_managed*/ false); - device_matrix_view_from_host d_mst_graph_num_edges( + batched_device_view_from_host d_mst_graph_num_edges( res, - raft::make_host_matrix_view( - mst_graph_num_edges_ptr, guarantee_connectivity ? graph_size : 0, 1)); + raft::make_host_matrix_view( + mst_graph_ptr, guarantee_connectivity ? 
graph_size : 0, output_graph_degree), + /*batch_size*/ batch_size, + /*host_writeback*/ false, + /*initialize*/ true, + /*hmm_as_managed*/ false); const uint32_t num_warps = 4; const dim3 threads_merge(raft::WarpSize * num_warps, 1, 1); - const dim3 blocks_merge(batch_size / num_warps, 1, 1); + const dim3 blocks_merge(raft::ceildiv(batch_size, num_warps), 1, 1); const size_t merge_smem_size = num_warps * output_graph_degree * sizeof(IdxT); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + auto mst_graph_view = d_mst_graph.next_view(); + auto mst_graph_num_edges_view = d_mst_graph_num_edges.next_view(); + auto output_view = d_output_graph.next_view(); kern_merge_graph <<>>( - output_device_accessible ? output_graph_ptr + (i_batch * batch_size * output_graph_degree) - : d_output_graph.data_handle(), + output_view.data_handle(), d_rev_graph_ptr, d_rev_graph_count_ptr, static_cast(graph_size), static_cast(output_graph_degree), - d_mst_graph.data_handle(), + mst_graph_view.data_handle(), static_cast(output_graph_degree), - d_mst_graph_num_edges.data_handle(), + mst_graph_num_edges_view.data_handle(), batch_size, i_batch, guarantee_connectivity, d_check_num_protected_edges.data_handle()); - - if (!output_device_accessible) { - size_t copy_size = - std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * - output_graph_degree; - raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, - d_output_graph.data_handle(), - copy_size, - raft::resource::get_cuda_stream(res)); - } } bool check_num_protected_edges = true; @@ -879,28 +882,21 @@ void make_reverse_graph_gpu(raft::resources const& res, std::min(static_cast(graph_size), static_cast(256 * 1024)); const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - bool output_device_accessible = is_ptr_device_accessible(output_graph_ptr); - auto d_output_graph = raft::make_device_mdarray( + batched_device_view_from_host d_output_graph( res, - default_ws_mr, - 
raft::make_extents(output_device_accessible ? 0 : batch_size, output_graph_degree)); + raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), + /*batch_size*/ batch_size, + /*host_writeback*/ false, + /*initialize*/ true, + /*hmm_as_managed*/ false); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { dim3 threads(256, 1, 1); dim3 blocks(1024, 1, 1); + auto output_view = d_output_graph.next_view(); - if (!output_device_accessible) { - size_t copy_size = - std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * - output_graph_degree; - raft::copy(d_output_graph.data_handle(), - output_graph_ptr + i_batch * batch_size * output_graph_degree, - copy_size, - raft::resource::get_cuda_stream(res)); - } kern_rev_graph_batched<<>>( - output_device_accessible ? output_graph_ptr + (i_batch * batch_size * output_graph_degree) - : d_output_graph.data_handle(), + output_view.data_handle(), d_rev_graph_ptr, d_rev_graph_count_ptr, static_cast(graph_size), @@ -1543,38 +1539,35 @@ void prune_graph_gpu(raft::resources const& res, auto host_stats = raft::make_host_vector(2); raft::matrix::fill(res, dev_stats.view(), uint64_t(0)); - // device_matrix_view_from_host d_input_graph( - // res, raft::make_host_matrix_view(knn_graph_ptr, graph_size, - // knn_graph_degree)); - batched_device_view_from_host d_input_graph( res, raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree), /*batch_size*/ graph_size, - /*read_only*/ true, /*host_writeback*/ false, /*initialize*/ true, - /*evict*/ true); + /*hmm_as_managed*/ true); auto input_view = d_input_graph.next_view(); - auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); - - bool output_device_accessible = is_ptr_device_accessible(output_graph_ptr); - auto d_output_graph = raft::make_device_mdarray( + batched_device_view_from_host d_output_graph( res, - default_ws_mr, - raft::make_extents(output_device_accessible ? 
0 : batch_size, output_graph_degree)); + raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), + /*batch_size*/ batch_size, + /*host_writeback*/ true, + /*initialize*/ false, + /*hmm_as_managed*/ false); + + auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + auto output_view = d_output_graph.next_view(); const uint32_t num_warps = 4; const dim3 threads_prune(raft::WarpSize * num_warps, 1, 1); - const dim3 blocks_prune(batch_size / num_warps, 1, 1); + const dim3 blocks_prune(raft::ceildiv(batch_size, num_warps), 1, 1); const size_t prune_smem_size = num_warps * knn_graph_degree * (sizeof(IdxT) + sizeof(uint32_t)); kern_fused_prune <<>>( input_view.data_handle(), - output_device_accessible ? output_graph_ptr + i_batch * batch_size * output_graph_degree - : d_output_graph.data_handle(), + output_view.data_handle(), graph_size, knn_graph_degree, output_graph_degree, @@ -1583,16 +1576,6 @@ void prune_graph_gpu(raft::resources const& res, d_invalid_neighbor_list.data_handle(), dev_stats.data_handle()); - if (!output_device_accessible) { - size_t copy_size = - std::min(static_cast(batch_size), graph_size - i_batch * batch_size) * - output_graph_degree; - raft::copy(output_graph_ptr + i_batch * batch_size * output_graph_degree, - d_output_graph.data_handle(), - copy_size, - raft::resource::get_cuda_stream(res)); - } - raft::resource::sync_stream(res); RAFT_LOG_DEBUG( "# Pruning kNN Graph on GPUs (%.1lf %%)\r", diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index c3d15e59f4..df6ef1ce6f 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -23,6 +24,7 @@ #include #include +#include #include namespace cuvs::neighbors::cagra::detail { @@ -328,7 +330,7 @@ void copy_with_padding( * @param read_only Whether 
the data is read only (only for managed memory) * @param host_writeback Whether to write back the data to the host (only for host memory) * @param initialize Whether to initialize the data (only for managed memory) - * @param evict Whether to evict the data (only for managed memory) + * @param discard Whether to discard the data (only for managed memory) * * @return The batched device view */ @@ -338,22 +340,24 @@ class batched_device_view_from_host { batched_device_view_from_host(raft::resources const& res, raft::host_matrix_view host_view, uint64_t batch_size, - bool read_only = false, bool host_writeback = false, bool initialize = true, - bool evict = false) + bool hmm_as_managed = false) : res_(res), host_view_(host_view), batch_size_(batch_size), offset_(0), - batch_id_(0), + batch_id_(-2), num_buffers_(2), - read_only_(read_only), host_writeback_(host_writeback), - next_buffer_pos_(0), - evict_(evict), - initialize_(initialize) + initialize_(initialize), + hmm_as_managed_(hmm_as_managed) { + if (host_view.extent(0) == 0) { + mem_type_ = cudaMemoryTypeDevice; + return; + } + cudaPointerAttributes attr; RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, host_view.data_handle())); mem_type_ = attr.type; @@ -361,27 +365,35 @@ class batched_device_view_from_host { // cudaMemoryTypeHost = 1 // cudaMemoryTypeDevice = 2 // cudaMemoryTypeManaged = 3 + // + // On HMM systems, unregistered (malloc) memory can have devicePointer != nullptr, + // meaning it's directly accessible from the GPU. 
Treat it like managed memory: + if (mem_type_ == cudaMemoryTypeUnregistered && attr.devicePointer != nullptr && + hmm_as_managed) { + mem_type_ = cudaMemoryTypeManaged; + } - prefetch_stream_ = raft::resource::get_cuda_stream(res); - writeback_stream_ = raft::resource::get_cuda_stream(res); - if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL)) { - if (raft::resource::get_stream_pool_size(res) >= 1) { - prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); - writeback_stream_ = raft::resource::get_stream_from_stream_pool(res); - } + if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) && + raft::resource::get_stream_pool_size(res) >= 1) { + prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); + writeback_stream_ = raft::resource::get_stream_from_stream_pool(res); + } else { + local_stream_pool_ = std::make_shared(2); + prefetch_stream_ = local_stream_pool_.value()->get_stream(); + writeback_stream_ = local_stream_pool_.value()->get_stream(); } // allocations if (mem_type_ == cudaMemoryTypeHost || mem_type_ == cudaMemoryTypeUnregistered) { device_mem_[0].emplace(raft::make_device_mdarray( res, - raft::resource::get_large_workspace_resource(res), + raft::resource::get_workspace_resource(res), raft::make_extents(batch_size, host_view.extent(1)))); device_ptr[0] = device_mem_[0]->data_handle(); if (batch_size < static_cast(host_view.extent(0))) { device_mem_[1].emplace(raft::make_device_mdarray( res, - raft::resource::get_large_workspace_resource(res), + raft::resource::get_workspace_resource(res), raft::make_extents(batch_size, host_view.extent(1)))); device_ptr[1] = device_mem_[1]->data_handle(); } @@ -389,7 +401,7 @@ class batched_device_view_from_host { num_buffers_ = 3; device_mem_[2].emplace(raft::make_device_mdarray( res, - raft::resource::get_large_workspace_resource(res), + raft::resource::get_workspace_resource(res), raft::make_extents(batch_size, host_view.extent(1)))); 
device_ptr[2] = device_mem_[2]->data_handle(); } @@ -400,18 +412,9 @@ class batched_device_view_from_host { // location_.type = CU_MEM_LOCATION_TYPE_DEVICE; location_.type = cudaMemLocationTypeDevice; location_.id = static_cast(raft::resource::get_device_id(res_)); - if (read_only_) { -#if CUDA_VERSION >= 13000 - RAFT_CUDA_TRY(cudaMemAdvise(host_view_.data_handle(), - host_view_.extent(0) * host_view_.extent(1) * sizeof(T), - cudaMemAdviseSetReadMostly, - location_)); -#else - RAFT_CUDA_TRY(cudaMemAdvise_v2(host_view_.data_handle(), - host_view_.extent(0) * host_view_.extent(1) * sizeof(T), - cudaMemAdviseSetReadMostly, - location_)); -#endif + if (!host_writeback_) { + advise_read_mostly(host_view_.data_handle(), + host_view_.extent(0) * host_view_.extent(1) * sizeof(T)); // TODO maybe also reset upon destruction } } @@ -422,95 +425,72 @@ class batched_device_view_from_host { bool prefetch_next_batch() { - // this function will ensure the device_ptr [next_buffer_pos_] is pointing to the correct memory - // after the next synchronization with the prefetch stream + batch_id_++; + + // ensure previous batch at position batch_id_ is ready + prefetch_stream_.synchronize(); + if (host_writeback_) { writeback_stream_.synchronize(); } - // if data is on host and we are writing to it we will have to copy it back - // if data is on host we will have to copy it to the device_ptr + // this step will + // * write back data from batch_id_ - 1 + // * prefetch data for batch_id_ + 1 - // if data is managed and evict_ is true we can evict the data from device memory - // if data is managed we have to prefetch it + // if data is on host and host_writeback_ is true we will have to copy it back + // if data is on host and initialize_ is true we will have to copy it to the device_ptr + + // if data is managed and !host_writeback_ we can discard the data from device memory + // if data is managed and initialize_ is true we can prefetch it to the device + // if data is managed and 
!initialize_ we can discard and prefetch the data location + + // if data is on device only this is almost a noop, just prepping the pointers + + RAFT_EXPECTS(offset_ <= host_view_.extent(0), "Offset out of bounds"); bool next_batch_exists = offset_ < static_cast(host_view_.extent(0)); if (next_batch_exists) { - actual_batch_size_[next_buffer_pos_] = - next_batch_exists ? min(batch_size_, host_view_.extent(0) - offset_) : 0; + // synchronize to ensure all previous operations are completed + // in particular all work on batch_id_ - 1 + raft::resource::sync_stream(res_); + + int32_t prefetch_pos = (batch_id_ + 1) % num_buffers_; + actual_batch_size_[prefetch_pos] = min(batch_size_, host_view_.extent(0) - offset_); switch (mem_type_) { case cudaMemoryTypeManaged: -#if CUDA_VERSION >= 13000 - if (evict_ && batch_id_ > 1) { - // evict last active - CUdeviceptr dptrs[] = {device_ptr[next_buffer_pos_]}; - size_t sizes[] = {batch_size_ * host_view_.extent(1) * sizeof(T)}; - size_t prefetchLocIdxs[] = {0}; - RAFT_CUDA_TRY(cuMemDiscardBatchAsync( - dptrs, sizes, 1, &location_, prefetchLocIdxs, 1, 0, prefetch_stream_)); + if (!host_writeback_ && batch_id_ > 1) { + uint32_t discard_pos = (batch_id_ - 1) % num_buffers_; + size_t discard_size = batch_size_ * host_view_.extent(1) * sizeof(T); + discard_managed_region(device_ptr[discard_pos], discard_size); } -#endif - // prefetch - device_ptr[next_buffer_pos_] = host_view_.data_handle() + offset_ * host_view_.extent(1); - if (initialize_) { - // managed API call to prefetch async -#if CUDA_VERSION >= 13000 - RAFT_CUDA_TRY(cudaMemPrefetchAsync( - device_ptr[next_buffer_pos_], - actual_batch_size_[next_buffer_pos_] * host_view_.extent(1) * sizeof(T), - location_, - 0, - prefetch_stream_)); -#else - RAFT_CUDA_TRY(cudaMemPrefetchAsync_v2( - device_ptr[next_buffer_pos_], - actual_batch_size_[next_buffer_pos_] * host_view_.extent(1) * sizeof(T), - location_, - 0, - prefetch_stream_)); -#endif - } else { - // managed API call to 
cuMemDiscardAndPrefetchBatchAsync (discard and prefetch batch) -#if CUDA_VERSION >= 13000 - CUdeviceptr dptrs[] = {device_ptr[next_buffer_pos_]}; - size_t sizes[] = {actual_batch_size_[next_buffer_pos_] * host_view_.extent(1) * - sizeof(T)}; - size_t prefetchLocIdxs[] = {0}; - RAFT_CUDA_TRY(cuMemDiscardAndPrefetchBatchAsync( - dptrs, sizes, 1, &location_, prefetchLocIdxs, 1, 0, prefetch_stream_)); -#endif - } - + // prefetch next position + device_ptr[prefetch_pos] = host_view_.data_handle() + offset_ * host_view_.extent(1); + prefetch_managed_region( + device_ptr[prefetch_pos], + actual_batch_size_[prefetch_pos] * host_view_.extent(1) * sizeof(T)); break; case cudaMemoryTypeHost: case cudaMemoryTypeUnregistered: - if (host_writeback_ && batch_id_ > 1) { - writeback_stream_.synchronize(); + if (host_writeback_ && batch_id_ > 0) { // copy back last active - uint32_t writeback_pos = (next_buffer_pos_ + num_buffers_ - 2) % num_buffers_; - uint64_t writeback_offset = (offset_ - 2 * batch_size_) * host_view_.extent(1); - raft::copy(host_view_.data_handle() + writeback_offset, - device_ptr[writeback_pos], - actual_batch_size_[writeback_pos] * host_view_.extent(1), - writeback_stream_); + uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; + uint64_t writeback_offset = (batch_id_ - 1) * batch_size_; + writeback_from_device_to_host(device_ptr[writeback_pos], writeback_offset, batch_size_); } if (initialize_) { // prefetch next position - raft::copy(device_ptr[next_buffer_pos_], - host_view_.data_handle() + offset_ * host_view_.extent(1), - actual_batch_size_[next_buffer_pos_] * host_view_.extent(1), - prefetch_stream_); + prefetch_from_host_to_device( + device_ptr[prefetch_pos], offset_, actual_batch_size_[prefetch_pos]); } break; case cudaMemoryTypeDevice: // just move pointer to next position - device_ptr[next_buffer_pos_] = host_view_.data_handle() + offset_ * host_view_.extent(1); + device_ptr[prefetch_pos] = host_view_.data_handle() + offset_ * 
host_view_.extent(1); break; } - offset_ += actual_batch_size_[next_buffer_pos_]; - // swap next_buffer_pos_ - next_buffer_pos_ = (next_buffer_pos_ + 1) % num_buffers_; + offset_ += actual_batch_size_[prefetch_pos]; } return next_batch_exists; @@ -525,33 +505,36 @@ class batched_device_view_from_host { // if data is on host and for_write --> make sure to copy back last active // if data is managed and evict --> evict last active - // make sure to sync on prefetch & writeback stream & res + // make sure to sync on prefetch stream & res switch (mem_type_) { case cudaMemoryTypeManaged: -#if CUDA_VERSION >= 13000 - if (evict_ && batch_id_ > 0) { - // managed API call to evict 2 - uint32_t evict_pos = (next_buffer_pos_ + num_buffers_ - 1) % num_buffers_; - CUdeviceptr dptrs[] = {device_ptr[evict_pos]}; - size_t sizes[] = {batch_size_ * host_view_.extent(1) * sizeof(T)}; - size_t prefetchLocIdxs[] = {0}; - RAFT_CUDA_TRY(cuMemDiscardBatchAsync( - dptrs, sizes, 1, &location_, prefetchLocIdxs, 1, 0, prefetch_stream_)); + if (!host_writeback_) { + uint32_t discard_pos = batch_id_ % num_buffers_; + size_t discard_size_rows = actual_batch_size_[discard_pos]; + if (batch_id_ > 0) { + discard_pos = (batch_id_ - 1) % num_buffers_; + discard_size_rows += batch_size_; + } + discard_managed_region(device_ptr[discard_pos], + discard_size_rows * host_view_.extent(1) * sizeof(T)); } - prefetch_stream_.synchronize(); -#endif + writeback_stream_.synchronize(); break; case cudaMemoryTypeHost: case cudaMemoryTypeUnregistered: - if (host_writeback_ && batch_id_ > 0) { - // TODO managed API call to copy back last active - uint32_t writeback_pos = (next_buffer_pos_ + num_buffers_ - 1) % num_buffers_; - uint64_t writeback_offset = - (offset_ - actual_batch_size_[writeback_pos]) * host_view_.extent(1); - raft::copy(host_view_.data_handle() + writeback_offset, - device_ptr[writeback_pos], - actual_batch_size_[writeback_pos] * host_view_.extent(1), - writeback_stream_); + if (host_writeback_) { + 
uint32_t writeback_pos_last = batch_id_ % num_buffers_; + if (batch_id_ > 0) { + uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; + uint64_t writeback_offset = (batch_id_ - 1) * batch_size_; + writeback_from_device_to_host(device_ptr[writeback_pos], writeback_offset, batch_size_); + } + { + uint64_t writeback_offset_last = batch_id_ * batch_size_; + writeback_from_device_to_host(device_ptr[writeback_pos_last], + writeback_offset_last, + actual_batch_size_[writeback_pos_last]); + } } writeback_stream_.synchronize(); break; @@ -569,39 +552,103 @@ class batched_device_view_from_host { */ raft::device_matrix_view next_view() { - RAFT_EXPECTS(batch_id_ * batch_size_ < host_view_.extent(0), "Batch index out of bounds"); - - // ensure current batch is ready - prefetch_stream_.synchronize(); + // special case for empty host view + if (host_view_.extent(0) == 0) { + return raft::make_device_matrix_view(nullptr, 0, host_view_.extent(1)); + } // trigger prefetch of next batch bool next_batch_exists = prefetch_next_batch(); - batch_id_++; + RAFT_EXPECTS(batch_id_ * batch_size_ < host_view_.extent(0), "Batch index out of bounds"); - uint32_t current_pos = - (next_buffer_pos_ + num_buffers_ - (next_batch_exists ? 
2 : 1)) % num_buffers_; + uint32_t current_pos = batch_id_ % num_buffers_; return raft::make_device_matrix_view( device_ptr[current_pos], actual_batch_size_[current_pos], host_view_.extent(1)); } private: - cudaMemoryType mem_type_; - const raft::resources& res_; - uint64_t batch_size_; - uint64_t offset_; - uint64_t num_buffers_; - bool initialize_; + void advise_read_mostly(T* ptr, size_t size) + { +#if CUDA_VERSION >= 13000 + RAFT_CUDA_TRY(cudaMemAdvise(ptr, size, cudaMemAdviseSetReadMostly, location_)); +#else + RAFT_CUDA_TRY(cudaMemAdvise_v2(ptr, size, cudaMemAdviseSetReadMostly, location_)); +#endif + } + + void discard_managed_region(T* dev_ptr, size_t size) + { +#if CUDA_VERSION >= 13000 + void* dptrs[1] = {dev_ptr}; + size_t sizes[1] = {size}; + RAFT_CUDA_TRY(cudaMemDiscardBatchAsync(dptrs, sizes, 1, 0, writeback_stream_)); +#endif + // FIXME: CUDA12 does not support discard + } + + void prefetch_managed_region(T* dev_ptr, size_t size) + { +#if CUDA_VERSION >= 13000 + if (initialize_) { + RAFT_CUDA_TRY(cudaMemPrefetchAsync(dev_ptr, size, location_, 0, prefetch_stream_)); + } else { + void* dptrs[1] = {dev_ptr}; + size_t sizes[1] = {size}; + RAFT_CUDA_TRY( + cudaMemDiscardAndPrefetchBatchAsync(dptrs, sizes, 1, location_, 0, prefetch_stream_)); + } +#else + // FIXME: CUDA12 does not support discard - so we just prefetch + if (initialize_) { + RAFT_CUDA_TRY(cudaMemPrefetchAsync_v2(dev_ptr, size, location_, 0, prefetch_stream_)); + } else { + RAFT_CUDA_TRY(cudaMemPrefetchAsync_v2(dev_ptr, size, location_, 0, prefetch_stream_)); + } +#endif + } + + void prefetch_from_host_to_device(T* dev_ptr, size_t src_row_offset, size_t num_rows) + { + raft::copy(dev_ptr, + host_view_.data_handle() + src_row_offset * host_view_.extent(1), + num_rows * host_view_.extent(1), + prefetch_stream_); + } + + void writeback_from_device_to_host(T* dev_ptr, size_t dst_row_offset, size_t num_rows) + { + raft::copy(host_view_.data_handle() + dst_row_offset * host_view_.extent(1), + 
dev_ptr, + num_rows * host_view_.extent(1), + writeback_stream_); + } + + // stream pool for local streams + std::optional> local_stream_pool_; rmm::cuda_stream_view prefetch_stream_; rmm::cuda_stream_view writeback_stream_; - bool read_only_; - bool host_writeback_; - bool evict_; - int32_t next_buffer_pos_; + + // configuration + const raft::resources& res_; + bool initialize_; // initialize the data on the device + bool host_writeback_; // write back the data to the host + bool hmm_as_managed_; // treat unregistered memory as managed memory + + // batch position information + uint64_t batch_size_; int32_t batch_id_; + uint64_t offset_; + cudaMemLocation location_; - std::optional> device_mem_[3]; + + // input pointer information + cudaMemoryType mem_type_; raft::host_matrix_view host_view_; + + // internal device buffers + uint64_t num_buffers_; + std::optional> device_mem_[3]; T* device_ptr[3]; uint32_t actual_batch_size_[3]; }; From 89b0d1c25bbff782cf906be7d9b2dc58a5927116 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 9 Mar 2026 21:57:25 +0000 Subject: [PATCH 16/22] implement fallback / simplify strategy --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 18 +-- cpp/src/neighbors/detail/cagra/utils.hpp | 110 ++++++++++-------- 2 files changed, 66 insertions(+), 62 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index ef8b1f8daf..a6e4c08350 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -794,8 +794,7 @@ void merge_graph_gpu(raft::resources const& res, raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), /*batch_size*/ batch_size, /*host_writeback*/ true, - /*initialize*/ true, - /*hmm_as_managed*/ false); + /*initialize*/ true); batched_device_view_from_host d_mst_graph( res, @@ -803,8 +802,7 @@ void merge_graph_gpu(raft::resources const& res, mst_graph_ptr, guarantee_connectivity ? 
graph_size : 0, output_graph_degree), /*batch_size*/ batch_size, /*host_writeback*/ false, - /*initialize*/ true, - /*hmm_as_managed*/ false); + /*initialize*/ true); batched_device_view_from_host d_mst_graph_num_edges( res, @@ -812,8 +810,7 @@ raft::make_host_matrix_view( mst_graph_num_edges_ptr, guarantee_connectivity ? graph_size : 0, 1), /*batch_size*/ batch_size, /*host_writeback*/ false, - /*initialize*/ true, - /*hmm_as_managed*/ false); + /*initialize*/ true); const uint32_t num_warps = 4; const dim3 threads_merge(raft::WarpSize * num_warps, 1, 1); @@ -887,8 +884,7 @@ void make_reverse_graph_gpu(raft::resources const& res, raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), /*batch_size*/ batch_size, /*host_writeback*/ false, - /*initialize*/ true, - /*hmm_as_managed*/ false); + /*initialize*/ true); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { dim3 threads(256, 1, 1); @@ -1544,8 +1540,7 @@ void prune_graph_gpu(raft::resources const& res, raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree), /*batch_size*/ graph_size, /*host_writeback*/ false, - /*initialize*/ true, - /*hmm_as_managed*/ true); + /*initialize*/ true); auto input_view = d_input_graph.next_view(); batched_device_view_from_host d_output_graph( res, @@ -1553,8 +1548,7 @@ raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), /*batch_size*/ batch_size, /*host_writeback*/ true, - /*initialize*/ false, - /*hmm_as_managed*/ false); + /*initialize*/ false); auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index df6ef1ce6f..8f6cfb063f 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -327,22 +327,25 @@ void copy_with_padding( * @param res The resources * @param host_view The
host view to create the batched device view from * @param batch_size The batch size - * @param read_only Whether the data is read only (only for managed memory) * @param host_writeback Whether to write back the data to the host (only for host memory) * @param initialize Whether to initialize the data (only for managed memory) - * @param discard Whether to discard the data (only for managed memory) * * @return The batched device view */ template class batched_device_view_from_host { public: + enum class memory_strategy { + device_only, // data is on device only (no copy needed) + copy_device, // data is explicitly moved to/from device buffers + managed_only, // data is on managed memory (system managed) + }; + batched_device_view_from_host(raft::resources const& res, raft::host_matrix_view host_view, uint64_t batch_size, bool host_writeback = false, - bool initialize = true, - bool hmm_as_managed = false) + bool initialize = true) : res_(res), host_view_(host_view), batch_size_(batch_size), @@ -350,29 +353,23 @@ class batched_device_view_from_host { batch_id_(-2), num_buffers_(2), host_writeback_(host_writeback), - initialize_(initialize), - hmm_as_managed_(hmm_as_managed) + initialize_(initialize) { if (host_view.extent(0) == 0) { - mem_type_ = cudaMemoryTypeDevice; + mem_strategy_ = memory_strategy::device_only; return; } cudaPointerAttributes attr; RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, host_view.data_handle())); - mem_type_ = attr.type; - // cudaMemoryTypeUnregistered = 0 - // cudaMemoryTypeHost = 1 - // cudaMemoryTypeDevice = 2 - // cudaMemoryTypeManaged = 3 - // - // On HMM systems, unregistered (malloc) memory can have devicePointer != nullptr, - // meaning it's directly accessible from the GPU. 
Treat it like managed memory: - if (mem_type_ == cudaMemoryTypeUnregistered && attr.devicePointer != nullptr && - hmm_as_managed) { - mem_type_ = cudaMemoryTypeManaged; + switch (attr.type) { + case cudaMemoryTypeUnregistered: + case cudaMemoryTypeHost: + case cudaMemoryTypeManaged: mem_strategy_ = memory_strategy::copy_device; break; + case cudaMemoryTypeDevice: mem_strategy_ = memory_strategy::device_only; break; } + // setup streams if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) && raft::resource::get_stream_pool_size(res) >= 1) { prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); @@ -383,32 +380,48 @@ class batched_device_view_from_host { writeback_stream_ = local_stream_pool_.value()->get_stream(); } - // allocations - if (mem_type_ == cudaMemoryTypeHost || mem_type_ == cudaMemoryTypeUnregistered) { - device_mem_[0].emplace(raft::make_device_mdarray( - res, - raft::resource::get_workspace_resource(res), - raft::make_extents(batch_size, host_view.extent(1)))); - device_ptr[0] = device_mem_[0]->data_handle(); - if (batch_size < static_cast(host_view.extent(0))) { - device_mem_[1].emplace(raft::make_device_mdarray( - res, - raft::resource::get_workspace_resource(res), - raft::make_extents(batch_size, host_view.extent(1)))); - device_ptr[1] = device_mem_[1]->data_handle(); - } - if (host_writeback_ && batch_size * 2 < static_cast(host_view.extent(0))) { - num_buffers_ = 3; - device_mem_[2].emplace(raft::make_device_mdarray( + // buffer allocations + if (mem_strategy_ == memory_strategy::copy_device) { + try { + device_mem_[0].emplace(raft::make_device_mdarray( res, raft::resource::get_workspace_resource(res), raft::make_extents(batch_size, host_view.extent(1)))); - device_ptr[2] = device_mem_[2]->data_handle(); + device_ptr[0] = device_mem_[0]->data_handle(); + if (batch_size < static_cast(host_view.extent(0))) { + device_mem_[1].emplace(raft::make_device_mdarray( + res, + 
raft::resource::get_workspace_resource(res), + raft::make_extents(batch_size, host_view.extent(1)))); + device_ptr[1] = device_mem_[1]->data_handle(); + } + if (host_writeback_ && batch_size * 2 < static_cast(host_view.extent(0))) { + num_buffers_ = 3; + device_mem_[2].emplace(raft::make_device_mdarray( + res, + raft::resource::get_workspace_resource(res), + raft::make_extents(batch_size, host_view.extent(1)))); + device_ptr[2] = device_mem_[2]->data_handle(); + } + } catch (std::bad_alloc& e) { + RAFT_LOG_DEBUG("Insufficient memory for device buffers"); + if (attr.devicePointer != nullptr) { + mem_strategy_ = memory_strategy::managed_only; + } else { + throw std::bad_alloc(); + } + } catch (raft::logic_error& e) { + RAFT_LOG_DEBUG("Insufficient memory for device buffers (logic error)"); + if (attr.devicePointer != nullptr) { + mem_strategy_ = memory_strategy::managed_only; + } else { + throw raft::logic_error("Insufficient memory for device buffers (logic error)"); + } } } // if data is managed and not for_write_ we can set the attribute on the device ptr - if (mem_type_ == cudaMemoryTypeManaged) { + if (mem_strategy_ == memory_strategy::managed_only) { // location_.type = CU_MEM_LOCATION_TYPE_DEVICE; location_.type = cudaMemLocationTypeDevice; location_.id = static_cast(raft::resource::get_device_id(res_)); @@ -428,7 +441,7 @@ class batched_device_view_from_host { batch_id_++; // ensure previous batch at position batch_id_ is ready - prefetch_stream_.synchronize(); + if (initialize_) { prefetch_stream_.synchronize(); } if (host_writeback_) { writeback_stream_.synchronize(); } // this step will @@ -456,8 +469,8 @@ class batched_device_view_from_host { int32_t prefetch_pos = (batch_id_ + 1) % num_buffers_; actual_batch_size_[prefetch_pos] = min(batch_size_, host_view_.extent(0) - offset_); - switch (mem_type_) { - case cudaMemoryTypeManaged: + switch (mem_strategy_) { + case memory_strategy::managed_only: if (!host_writeback_ && batch_id_ > 1) { uint32_t 
discard_pos = (batch_id_ - 1) % num_buffers_; size_t discard_size = batch_size_ * host_view_.extent(1) * sizeof(T); @@ -469,8 +482,7 @@ class batched_device_view_from_host { device_ptr[prefetch_pos], actual_batch_size_[prefetch_pos] * host_view_.extent(1) * sizeof(T)); break; - case cudaMemoryTypeHost: - case cudaMemoryTypeUnregistered: + case memory_strategy::copy_device: if (host_writeback_ && batch_id_ > 0) { // copy back last active uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; @@ -484,7 +496,7 @@ class batched_device_view_from_host { } break; - case cudaMemoryTypeDevice: + case memory_strategy::device_only: // just move pointer to next position device_ptr[prefetch_pos] = host_view_.data_handle() + offset_ * host_view_.extent(1); break; @@ -506,8 +518,8 @@ class batched_device_view_from_host { // if data is managed and evict --> evict last active // make sure to sync on prefetch stream & res - switch (mem_type_) { - case cudaMemoryTypeManaged: + switch (mem_strategy_) { + case memory_strategy::managed_only: if (!host_writeback_) { uint32_t discard_pos = batch_id_ % num_buffers_; size_t discard_size_rows = actual_batch_size_[discard_pos]; @@ -520,8 +532,7 @@ class batched_device_view_from_host { } writeback_stream_.synchronize(); break; - case cudaMemoryTypeHost: - case cudaMemoryTypeUnregistered: + case memory_strategy::copy_device: if (host_writeback_) { uint32_t writeback_pos_last = batch_id_ % num_buffers_; if (batch_id_ > 0) { @@ -538,7 +549,7 @@ class batched_device_view_from_host { } writeback_stream_.synchronize(); break; - case cudaMemoryTypeDevice: break; + case memory_strategy::device_only: break; } } @@ -630,10 +641,10 @@ class batched_device_view_from_host { rmm::cuda_stream_view writeback_stream_; // configuration + memory_strategy mem_strategy_; const raft::resources& res_; bool initialize_; // initialize the data on the device bool host_writeback_; // write back the data to the host - bool hmm_as_managed_; // treat unregistered memory 
as managed memory // batch position information uint64_t batch_size_; @@ -643,7 +654,6 @@ class batched_device_view_from_host { cudaMemLocation location_; // input pointer information - cudaMemoryType mem_type_; raft::host_matrix_view host_view_; // internal device buffers From d0e3daefdfc7fcdec3ceaaa62a8d95134a726f15 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Tue, 10 Mar 2026 17:31:23 +0000 Subject: [PATCH 17/22] add logging / remove stats compute --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 51 +++------------ cpp/src/neighbors/detail/cagra/utils.hpp | 62 ++++++++++++------- 2 files changed, 46 insertions(+), 67 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index a6e4c08350..b5e055820d 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -648,44 +648,6 @@ __global__ void kern_mst_opt_postprocessing(IdxT* outgoing_num_edges, // [graph } } -template -uint64_t pos_in_array(T val, const T* array, uint64_t num) -{ - for (uint64_t i = 0; i < num; i++) { - if (val == array[i]) { return i; } - } - return num; -} - -template -void shift_array(T* array, uint64_t num) -{ - for (uint64_t i = num; i > 0; i--) { - array[i] = array[i - 1]; - } -} - -template -void log_replaced_edges_stats(const IdxT* output_graph_ptr, - uint64_t graph_size, - uint64_t output_graph_degree) -{ - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/stats"); - uint64_t num_replaced_edges = 0; -#pragma omp parallel for reduction(+ : num_replaced_edges) - for (uint64_t i = 0; i < graph_size; i++) { - for (uint64_t k = 0; k < output_graph_degree; k++) { - const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; - const uint64_t pos = - pos_in_array(j, output_graph_ptr + (output_graph_degree * i), output_graph_degree); - if (pos == output_graph_degree) { num_replaced_edges += 1; } - } - } - RAFT_LOG_DEBUG("# Average number of 
replaced edges per node: %.2f", - (double)num_replaced_edges / graph_size); -} - template void log_incoming_edges_histogram(const IdxT* output_graph_ptr, uint64_t graph_size, @@ -755,7 +717,10 @@ void check_duplicates_and_out_of_range(const IdxT* output_graph_ptr, for (uint32_t k = j + 1; k < output_graph_degree; k++) { const auto neighbor_b = my_out_graph[k]; - if (neighbor_a == neighbor_b) { num_dup++; } + if (neighbor_a == neighbor_b) { + num_dup++; + break; + } } } } @@ -1606,10 +1571,10 @@ void prune_graph_gpu(raft::resources const& res, } // TODO allow pinned input for both knn_graph and new_graph -template +template void optimize(raft::resources const& res, - InOutMatrixView knn_graph, - InOutMatrixView new_graph, + InputMatrixView knn_graph, + OutputMatrixView new_graph, const bool guarantee_connectivity = true, const bool use_gpu = true) { @@ -1707,8 +1672,6 @@ void optimize(raft::resources const& res, if (is_ptr_host_accessible(new_graph.data_handle())) { // following checks require host access - log_replaced_edges_stats(new_graph.data_handle(), graph_size, output_graph_degree); - log_incoming_edges_histogram(new_graph.data_handle(), graph_size, output_graph_degree); check_duplicates_and_out_of_range( diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 8f6cfb063f..75883a9636 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -360,25 +360,18 @@ class batched_device_view_from_host { return; } - cudaPointerAttributes attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, host_view.data_handle())); - switch (attr.type) { + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr_, host_view.data_handle())); + switch (attr_.type) { case cudaMemoryTypeUnregistered: case cudaMemoryTypeHost: case cudaMemoryTypeManaged: mem_strategy_ = memory_strategy::copy_device; break; case cudaMemoryTypeDevice: mem_strategy_ = memory_strategy::device_only; break; } - // setup streams - 
if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) && - raft::resource::get_stream_pool_size(res) >= 1) { - prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); - writeback_stream_ = raft::resource::get_stream_from_stream_pool(res); - } else { - local_stream_pool_ = std::make_shared(2); - prefetch_stream_ = local_stream_pool_.value()->get_stream(); - writeback_stream_ = local_stream_pool_.value()->get_stream(); - } + RAFT_LOG_DEBUG("Memory strategy: %d for type %d, size %zu", + static_cast(mem_strategy_), + static_cast(attr_.type), + host_view.extent(0) * host_view.extent(1) * sizeof(T)); // buffer allocations if (mem_strategy_ == memory_strategy::copy_device) { @@ -405,14 +398,14 @@ class batched_device_view_from_host { } } catch (std::bad_alloc& e) { RAFT_LOG_DEBUG("Insufficient memory for device buffers"); - if (attr.devicePointer != nullptr) { + if (attr_.devicePointer != nullptr) { mem_strategy_ = memory_strategy::managed_only; } else { throw std::bad_alloc(); } } catch (raft::logic_error& e) { RAFT_LOG_DEBUG("Insufficient memory for device buffers (logic error)"); - if (attr.devicePointer != nullptr) { + if (attr_.devicePointer != nullptr) { mem_strategy_ = memory_strategy::managed_only; } else { throw raft::logic_error("Insufficient memory for device buffers (logic error)"); @@ -420,6 +413,17 @@ class batched_device_view_from_host { } } + // setup streams + if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) && + raft::resource::get_stream_pool_size(res) >= 1) { + prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); + writeback_stream_ = raft::resource::get_stream_from_stream_pool(res); + } else { + local_stream_pool_ = std::make_shared(2); + prefetch_stream_ = local_stream_pool_.value()->get_stream(); + writeback_stream_ = local_stream_pool_.value()->get_stream(); + } + // if data is managed and not for_write_ we can set the attribute on the device ptr if (mem_strategy_ == 
memory_strategy::managed_only) { // location_.type = CU_MEM_LOCATION_TYPE_DEVICE; @@ -621,18 +625,29 @@ class batched_device_view_from_host { void prefetch_from_host_to_device(T* dev_ptr, size_t src_row_offset, size_t num_rows) { - raft::copy(dev_ptr, - host_view_.data_handle() + src_row_offset * host_view_.extent(1), - num_rows * host_view_.extent(1), - prefetch_stream_); + const size_t n_elem = num_rows * host_view_.extent(1); + const size_t n_bytes = n_elem * sizeof(T); + RAFT_CUDA_TRY(cudaHostRegister(host_view_.data_handle() + src_row_offset * host_view_.extent(1), + n_bytes, + cudaHostRegisterDefault)); + // use memcpy instead of raft::copy to avoid strange behavior with HMM/ATS memory + RAFT_CUDA_TRY(cudaMemcpyAsync(dev_ptr, + host_view_.data_handle() + src_row_offset * host_view_.extent(1), + n_bytes, + cudaMemcpyHostToDevice, + prefetch_stream_)); } void writeback_from_device_to_host(T* dev_ptr, size_t dst_row_offset, size_t num_rows) { - raft::copy(host_view_.data_handle() + dst_row_offset * host_view_.extent(1), - dev_ptr, - num_rows * host_view_.extent(1), - writeback_stream_); + const size_t n_elem = num_rows * host_view_.extent(1); + const size_t n_bytes = n_elem * sizeof(T); + // use memcpy instead of raft::copy to avoid strange behavior with HMM/ATS memory + RAFT_CUDA_TRY(cudaMemcpyAsync(host_view_.data_handle() + dst_row_offset * host_view_.extent(1), + dev_ptr, + n_bytes, + cudaMemcpyDeviceToHost, + writeback_stream_)); } // stream pool for local streams @@ -655,6 +670,7 @@ class batched_device_view_from_host { // input pointer information raft::host_matrix_view host_view_; + cudaPointerAttributes attr_; // internal device buffers uint64_t num_buffers_; From ec45fd251d90cd8713c58252d8258ebee3b700a8 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Tue, 10 Mar 2026 22:46:18 +0000 Subject: [PATCH 18/22] add test, persist stream pool, cleanup --- cpp/src/neighbors/detail/cagra/utils.hpp | 214 ++++++++++-------- cpp/tests/CMakeLists.txt | 1 + 
.../test_batched_device_view_from_host.cu | 205 +++++++++++++++++ 3 files changed, 326 insertions(+), 94 deletions(-) create mode 100644 cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 75883a9636..44d87d2993 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -322,15 +322,22 @@ void copy_with_padding( * writeback of the data Each batch can be referenced exactlyonce by calling the next_view() * function * + * Usage: + * ``` + * batched_device_view_from_host view(res, host_view, batch_size, host_writeback, + * initialize); while (view.next_view().extent(0) > 0) { auto device_view = view.next_view(); + * // use device_view + * } + * ``` + * + * The call to next_view() will + * * synchronize on all previous operations / increments batch_id_ + * * (optionally) write back the data of the previous batch to the host + * * (optionally) prefetch the data of the next batch + * * return the view of the current batch + * * @tparam T The type of the data * @tparam IdxT The type of the index - * @param res The resources - * @param host_view The host view to create the batched device view from - * @param batch_size The batch size - * @param host_writeback Whether to write back the data to the host (only for host memory) - * @param initialize Whether to initialize the data (only for managed memory) - * - * @return The batched device view */ template class batched_device_view_from_host { @@ -341,6 +348,18 @@ class batched_device_view_from_host { managed_only, // data is on managed memory (system managed) }; + /** + * Create a batched device view from a host view and will handle the prefetch and + * writeback of the data. Each batch can be referenced exactly once by calling the next_view() + * method. 
+ * + * @param res The resources to use + * @param host_view The host view to create the batched device view from + * @param batch_size The batch size + * @param host_writeback Whether to write back the data to the host (only for host memory) + * (default: false) + * @param initialize Whether to initialize the data (only for managed memory) (default: true) + */ batched_device_view_from_host(raft::resources const& res, raft::host_matrix_view host_view, uint64_t batch_size, @@ -360,6 +379,9 @@ class batched_device_view_from_host { return; } + RAFT_EXPECTS(host_writeback_ || initialize_, + "At least one of host_writeback or initialize must be true"); + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr_, host_view.data_handle())); switch (attr_.type) { case cudaMemoryTypeUnregistered: @@ -388,7 +410,8 @@ class batched_device_view_from_host { raft::make_extents(batch_size, host_view.extent(1)))); device_ptr[1] = device_mem_[1]->data_handle(); } - if (host_writeback_ && batch_size * 2 < static_cast(host_view.extent(0))) { + if (host_writeback_ && initialize_ && + batch_size * 2 < static_cast(host_view.extent(0))) { num_buffers_ = 3; device_mem_[2].emplace(raft::make_device_mdarray( res, @@ -397,15 +420,16 @@ class batched_device_view_from_host { device_ptr[2] = device_mem_[2]->data_handle(); } } catch (std::bad_alloc& e) { - RAFT_LOG_DEBUG("Insufficient memory for device buffers"); if (attr_.devicePointer != nullptr) { + RAFT_LOG_DEBUG("Insufficient memory for device buffers, switching to managed memory"); mem_strategy_ = memory_strategy::managed_only; } else { throw std::bad_alloc(); } } catch (raft::logic_error& e) { - RAFT_LOG_DEBUG("Insufficient memory for device buffers (logic error)"); if (attr_.devicePointer != nullptr) { + RAFT_LOG_DEBUG( + "Insufficient memory for device buffers (logic error), switching to managed memory"); mem_strategy_ = memory_strategy::managed_only; } else { throw raft::logic_error("Insufficient memory for device buffers (logic error)"); @@ 
-413,20 +437,18 @@ class batched_device_view_from_host { } } - // setup streams - if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) && - raft::resource::get_stream_pool_size(res) >= 1) { - prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); - writeback_stream_ = raft::resource::get_stream_from_stream_pool(res); - } else { - local_stream_pool_ = std::make_shared(2); - prefetch_stream_ = local_stream_pool_.value()->get_stream(); - writeback_stream_ = local_stream_pool_.value()->get_stream(); + // setup stream pool if not already present + size_t required_streams = host_writeback_ && initialize_ ? 2 : 1; + if (!res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) || + raft::resource::get_stream_pool_size(res) < required_streams) { + // always create at least 2 streams to account for subsequent iterator calls + raft::resource::set_cuda_stream_pool(res, std::make_shared(2)); } + prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); + writeback_stream_ = raft::resource::get_stream_from_stream_pool(res); // if data is managed and not for_write_ we can set the attribute on the device ptr if (mem_strategy_ == memory_strategy::managed_only) { - // location_.type = CU_MEM_LOCATION_TYPE_DEVICE; location_.type = cudaMemLocationTypeDevice; location_.id = static_cast(raft::resource::get_device_id(res_)); if (!host_writeback_) { @@ -440,6 +462,84 @@ class batched_device_view_from_host { prefetch_next_batch(); } + ~batched_device_view_from_host() noexcept + { + raft::resource::sync_stream(res_); + + // if data is on host and for_write --> make sure to copy back last active + // if data is managed and evict --> evict last active + + // make sure to sync on prefetch stream & res + switch (mem_strategy_) { + case memory_strategy::managed_only: + if (!host_writeback_) { + uint32_t discard_pos = batch_id_ % num_buffers_; + size_t discard_size_rows = actual_batch_size_[discard_pos]; + if (batch_id_ > 0) { + 
discard_pos = (batch_id_ - 1) % num_buffers_; + discard_size_rows += batch_size_; + } + discard_managed_region(device_ptr[discard_pos], + discard_size_rows * host_view_.extent(1) * sizeof(T)); + writeback_stream_.synchronize(); + } + break; + case memory_strategy::copy_device: + if (host_writeback_) { + uint32_t writeback_pos_last = batch_id_ % num_buffers_; + if (batch_id_ > 0) { + uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; + uint64_t writeback_offset = (batch_id_ - 1) * batch_size_; + writeback_from_device_to_host(device_ptr[writeback_pos], writeback_offset, batch_size_); + } + { + uint64_t writeback_offset_last = batch_id_ * batch_size_; + writeback_from_device_to_host(device_ptr[writeback_pos_last], + writeback_offset_last, + actual_batch_size_[writeback_pos_last]); + } + writeback_stream_.synchronize(); + } + break; + case memory_strategy::device_only: break; + } + } + + /** + * Returns the next view of the batch + * + * This function will ensure the next batch is ready and will trigger the prefetch of the + * subsequent next batch. If writeback is enabled, the last active batch will be written back to + * the host. + * + * @return The next view of the batch + */ + raft::device_matrix_view next_view() + { + bool end_of_data = static_cast((batch_id_ + 1) * batch_size_) >= + static_cast(host_view_.extent(0)); + + // special case for empty host view or last batch surpassed + if (end_of_data) { + return raft::make_device_matrix_view(nullptr, 0, host_view_.extent(1)); + } + + // trigger prefetch of next batch (also increments batch_id_) + prefetch_next_batch(); + + uint32_t current_pos = batch_id_ % num_buffers_; + return raft::make_device_matrix_view( + device_ptr[current_pos], actual_batch_size_[current_pos], host_view_.extent(1)); + } + + private: + /** + * Prefetch the next batch + * + * This function will prefetch the next batch and will handle the writeback of the data. 
+ * + * @return True if the next batch exists, false otherwise + */ bool prefetch_next_batch() { batch_id_++; @@ -512,77 +612,6 @@ class batched_device_view_from_host { return next_batch_exists; } - ~batched_device_view_from_host() noexcept - { - prefetch_stream_.synchronize(); - writeback_stream_.synchronize(); - raft::resource::sync_stream(res_); - - // if data is on host and for_write --> make sure to copy back last active - // if data is managed and evict --> evict last active - - // make sure to sync on prefetch stream & res - switch (mem_strategy_) { - case memory_strategy::managed_only: - if (!host_writeback_) { - uint32_t discard_pos = batch_id_ % num_buffers_; - size_t discard_size_rows = actual_batch_size_[discard_pos]; - if (batch_id_ > 0) { - discard_pos = (batch_id_ - 1) % num_buffers_; - discard_size_rows += batch_size_; - } - discard_managed_region(device_ptr[discard_pos], - discard_size_rows * host_view_.extent(1) * sizeof(T)); - } - writeback_stream_.synchronize(); - break; - case memory_strategy::copy_device: - if (host_writeback_) { - uint32_t writeback_pos_last = batch_id_ % num_buffers_; - if (batch_id_ > 0) { - uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; - uint64_t writeback_offset = (batch_id_ - 1) * batch_size_; - writeback_from_device_to_host(device_ptr[writeback_pos], writeback_offset, batch_size_); - } - { - uint64_t writeback_offset_last = batch_id_ * batch_size_; - writeback_from_device_to_host(device_ptr[writeback_pos_last], - writeback_offset_last, - actual_batch_size_[writeback_pos_last]); - } - } - writeback_stream_.synchronize(); - break; - case memory_strategy::device_only: break; - } - } - - /** - * Returns the next view of the batch - * - * This function will ensure the next batch is ready and will trigger the prefetch of the - * subsequent next batch - * - * @return The next view of the batch - */ - raft::device_matrix_view next_view() - { - // special case for empty host view - if (host_view_.extent(0) == 0) { - 
return raft::make_device_matrix_view(nullptr, 0, host_view_.extent(1)); - } - - // trigger prefetch of next batch - bool next_batch_exists = prefetch_next_batch(); - - RAFT_EXPECTS(batch_id_ * batch_size_ < host_view_.extent(0), "Batch index out of bounds"); - - uint32_t current_pos = batch_id_ % num_buffers_; - return raft::make_device_matrix_view( - device_ptr[current_pos], actual_batch_size_[current_pos], host_view_.extent(1)); - } - - private: void advise_read_mostly(T* ptr, size_t size) { #if CUDA_VERSION >= 13000 @@ -627,9 +656,6 @@ class batched_device_view_from_host { { const size_t n_elem = num_rows * host_view_.extent(1); const size_t n_bytes = n_elem * sizeof(T); - RAFT_CUDA_TRY(cudaHostRegister(host_view_.data_handle() + src_row_offset * host_view_.extent(1), - n_bytes, - cudaHostRegisterDefault)); // use memcpy instead of raft::copy to avoid strange behavior with HMM/ATS memory RAFT_CUDA_TRY(cudaMemcpyAsync(dev_ptr, host_view_.data_handle() + src_row_offset * host_view_.extent(1), diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 35794adf9b..77fd18c7d3 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -173,6 +173,7 @@ ConfigureTest( ConfigureTest( NAME NEIGHBORS_ANN_CAGRA_HELPERS_TEST PATH neighbors/ann_cagra/test_optimize_uint32_t.cu + neighbors/ann_cagra/test_batched_device_view_from_host.cu GPUS 1 PERCENT 100 ) diff --git a/cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu b/cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu new file mode 100644 index 0000000000..1e1cc13093 --- /dev/null +++ b/cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu @@ -0,0 +1,205 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../../src/neighbors/detail/cagra/utils.hpp" + +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra { + +using IdxT = uint32_t; + +struct BatchConfig { + bool initialize; + bool host_writeback; +}; + +struct DimsConfig { + int64_t n_rows; + int64_t n_cols; + uint64_t batch_size; +}; + +class BatchedDeviceViewFromHostTest : public ::testing::Test { + protected: + void SetUp() override { raft::resource::sync_stream(res); } + + /** + * Run batched_device_view_from_host over host data, copy device views back, + * and verify against the input. + */ + template + void run_and_verify_batched(InputMatrixView input_view, + uint64_t batch_size, + bool host_writeback, + bool initialize) + { + int64_t n_rows = input_view.extent(0); + int64_t n_cols = input_view.extent(1); + + std::vector readback(n_rows * n_cols); + + int64_t total_processed = 0; + + { + cagra::detail::batched_device_view_from_host batched( + res, + raft::make_host_matrix_view(input_view.data_handle(), n_rows, n_cols), + batch_size, + host_writeback, + initialize); + while (true) { + auto dev_view = batched.next_view(); + if (dev_view.extent(0) == 0) break; + + if (initialize) { + raft::copy(readback.data() + total_processed * n_cols, + dev_view.data_handle(), + dev_view.extent(0) * dev_view.extent(1), + raft::resource::get_cuda_stream(res)); + } + if (host_writeback) { raft::matrix::fill(res, dev_view, IdxT(17)); } + total_processed += dev_view.extent(0); + } + } + raft::resource::sync_stream(res); + + EXPECT_EQ(total_processed, n_rows); + if (initialize) { + for (int64_t i = 0; i < n_rows * n_cols; ++i) { + EXPECT_EQ(readback[i], IdxT(13)) << "Mismatch (initialize) at index " << i; + } + } + if (host_writeback) { + auto readback_view = + 
raft::make_host_matrix_view(readback.data(), n_rows, n_cols); + raft::copy(res, readback_view, input_view); + raft::resource::sync_stream(res); + for (int64_t i = 0; i < n_rows * n_cols; ++i) { + EXPECT_EQ(readback[i], IdxT(17)) << "Mismatch (host_writeback) at index " << i; + } + } + } + + raft::resources res; +}; + +TEST_F(BatchedDeviceViewFromHostTest, EmptyView) +{ + auto host_empty = raft::make_host_matrix(0, 8); + auto host_view = host_empty.view(); + cagra::detail::batched_device_view_from_host batched( + res, host_view, /*batch_size=*/128, /*host_writeback=*/false, /*initialize=*/true); + + auto view = batched.next_view(); + EXPECT_EQ(view.extent(0), 0); + EXPECT_EQ(view.extent(1), 8); + EXPECT_EQ(view.data_handle(), nullptr); +} + +using BatchDimsParam = std::tuple; + +class BatchedDeviceViewFromHostParameterizedTest + : public BatchedDeviceViewFromHostTest, + public ::testing::WithParamInterface {}; + +TEST_P(BatchedDeviceViewFromHostParameterizedTest, VectorHostData) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + std::vector host_data(n_rows * n_cols); + auto host_view = raft::make_host_matrix_view(host_data.data(), n_rows, n_cols); + + std::fill(host_view.data_handle(), host_view.data_handle() + n_rows * n_cols, IdxT(13)); + + run_and_verify_batched(host_view, batch_size, host_writeback, initialize); +} + +TEST_P(BatchedDeviceViewFromHostParameterizedTest, PinnedMemory) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + auto host_matrix = raft::make_pinned_matrix(res, n_rows, n_cols); + auto host_view = host_matrix.view(); + + std::fill(host_view.data_handle(), host_view.data_handle() + n_rows * n_cols, IdxT(13)); + + run_and_verify_batched(host_view, batch_size, host_writeback, initialize); +} + 
+TEST_P(BatchedDeviceViewFromHostParameterizedTest, ManagedMemory) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + auto host_matrix = raft::make_managed_matrix(res, n_rows, n_cols); + auto host_view = host_matrix.view(); + + std::fill(host_view.data_handle(), host_view.data_handle() + n_rows * n_cols, IdxT(13)); + + run_and_verify_batched(host_view, batch_size, host_writeback, initialize); +} + +TEST_P(BatchedDeviceViewFromHostParameterizedTest, DeviceMemory) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + auto host_matrix = raft::make_device_matrix(res, n_rows, n_cols); + auto host_view = host_matrix.view(); + + raft::matrix::fill(res, host_view, IdxT(13)); + + run_and_verify_batched(host_view, batch_size, host_writeback, initialize); +} + +static const std::array kBatchConfigs = {{ + {/*initialize=*/true, /*host_writeback=*/false}, + {/*initialize=*/false, /*host_writeback=*/true}, + {/*initialize=*/true, /*host_writeback=*/true}, +}}; + +static const std::array kDimsConfigs = {{ + {/*n_rows=*/64, /*n_cols=*/32, /*batch_size=*/256}, // rows less than batch size, single batch + {/*n_rows=*/64, /*n_cols=*/32, /*batch_size=*/64}, // single batch + {/*n_rows=*/256, /*n_cols=*/32, /*batch_size=*/32}, // multiple batches + {/*n_rows=*/500, + /*n_cols=*/32, + /*batch_size=*/128}, // multiple batches, partial batch in the end +}}; + +INSTANTIATE_TEST_SUITE_P(BatchConfigs, + BatchedDeviceViewFromHostParameterizedTest, + ::testing::Combine(::testing::ValuesIn(kBatchConfigs), + ::testing::ValuesIn(kDimsConfigs))); + +} // namespace cuvs::neighbors::cagra From c412138a0dd6e3b81fa9bc4e10a1b546d71c5476 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Wed, 11 Mar 2026 00:04:52 +0000 Subject: [PATCH 19/22] switch to cooperative groups as __reduce_min_sync 
causes issues --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index b5e055820d..2444350253 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -29,11 +29,16 @@ #include #include +#include +#include + #include #include #include #include +namespace cg = cooperative_groups; + namespace cuvs::neighbors::cagra::detail::graph { // unnamed namespace to avoid multiple definition error @@ -196,6 +201,9 @@ __global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_ { extern __shared__ unsigned char smem_buf[]; + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + const uint32_t wid = threadIdx.x / raft::WarpSize; const uint32_t lane_id = threadIdx.x % raft::WarpSize; @@ -207,8 +215,7 @@ __global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_ uint64_t* const num_retain = stats; uint64_t* const num_full = stats + 1; - const unsigned warp_mask = 0xffffffff; - const uint32_t maxval16 = 0x0000ffff; + const uint32_t maxval16 = 0x0000ffff; const uint64_t nid_batch = blockIdx.x * num_warps + wid; const uint64_t nid = nid_batch + (batch_size * batch_id); @@ -255,11 +262,7 @@ __global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_ __syncwarp(); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 1); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 2); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 4); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 8); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 16); + num_edges_no_detour = cg::reduce(warp, num_edges_no_detour, cg::plus()); 
num_edges_no_detour = min(num_edges_no_detour, output_graph_degree); if (lane_id == 0) { @@ -280,7 +283,7 @@ __global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_ } uint32_t local_min_with_tag = (local_min << 16) | ((uint32_t)local_idx); - uint32_t warp_min_with_tag = __reduce_min_sync(warp_mask, local_min_with_tag); + uint32_t warp_min_with_tag = cg::reduce(warp, local_min_with_tag, cg::less()); uint32_t warp_min_count = warp_min_with_tag >> 16; uint32_t warp_local_idx = warp_min_with_tag & 0xffff; @@ -294,7 +297,7 @@ __global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_ for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { if (smem_indices[k] == selected_node) { smem_num_detour[k] = maxval16; } } - __syncwarp(warp_mask); + __syncwarp(); if (lane_id == 0) { output_graph_ptr[nid_batch * output_graph_degree + i] = selected_node; } } @@ -312,7 +315,10 @@ __device__ unsigned int warp_pos_in_array(T val, const T* array, uint64_t num) break; } } - ret = __reduce_min_sync(0xffffffff, ret); + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + ret = cg::reduce(warp, ret, cg::less()); return ret; } From ab01bab594e4337a9b6530a686e0d8642ce61866 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 13 Mar 2026 18:43:55 +0000 Subject: [PATCH 20/22] back to column wise reverse graph creation to boost closer connections --- cpp/src/neighbors/detail/cagra/graph_core.cuh | 124 +++++++++++------- 1 file changed, 78 insertions(+), 46 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 88c13c139e..5d43da851b 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -171,24 +171,38 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, } } +template +__global__ void kern_make_rev_graph(const 
IdxT* const dest_nodes, // [graph_size] + IdxT* const rev_graph, // [size, degree] + uint32_t* const rev_graph_count, // [graph_size] + const uint32_t graph_size, + const uint32_t degree) +{ + const uint32_t tid = threadIdx.x + (blockDim.x * blockIdx.x); + const uint32_t tnum = blockDim.x * gridDim.x; + + for (uint32_t src_id = tid; src_id < graph_size; src_id += tnum) { + const IdxT dest_id = dest_nodes[src_id]; + if (dest_id >= graph_size) continue; + + const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); + if (pos < degree) { rev_graph[pos + ((uint64_t)degree * dest_id)] = src_id; } + } +} + template -__global__ void kern_rev_graph_batched(const IdxT* const dest_nodes, // [batch_size, degree] - IdxT* const rev_graph, // [graph_size, degree] - uint32_t* const rev_graph_count, // [graph_size] - const uint32_t graph_size, - const uint32_t degree, - const uint32_t batch_size, - const uint32_t batch_id) +__global__ void kern_make_rev_graph_k(const IdxT* const output_graph, // [graph_size, degree] + IdxT* const rev_graph, // [graph_size, degree] + uint32_t* const rev_graph_count, // [graph_size] + const uint32_t graph_size, + const uint32_t degree, + uint64_t k) { const uint64_t tid = threadIdx.x + (blockDim.x * blockIdx.x); const uint64_t tnum = blockDim.x * gridDim.x; - const uint64_t block_batch_size = min(batch_size, graph_size - batch_id * batch_size); - - for (uint64_t idx = tid; idx < block_batch_size * degree; idx += tnum) { - const IdxT dest_id = dest_nodes[idx]; - const uint32_t src_id = idx / degree; - + for (uint64_t src_id = tid; src_id < graph_size; src_id += tnum) { + IdxT dest_id = output_graph[k + (degree * src_id)]; if (dest_id >= graph_size) continue; const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); @@ -840,50 +854,67 @@ void make_reverse_graph_gpu(raft::resources const& res, uint64_t output_graph_degree) { raft::common::nvtx::range block_scope( - "cagra::graph::optimize/reverse"); + "cagra::graph::optimize/reverse2"); - auto 
default_ws_mr = raft::resource::get_workspace_resource(res); + auto d_rev_graph = + raft::make_device_vector_view(d_rev_graph_ptr, graph_size * output_graph_degree); + auto d_rev_graph_count = + raft::make_device_vector_view(d_rev_graph_count_ptr, graph_size); - raft::matrix::fill( - res, - raft::make_device_vector_view(d_rev_graph_ptr, graph_size * output_graph_degree), - IdxT(-1)); + // + // Make reverse graph + // + const double time_make_start = cur_time(); - raft::matrix::fill( - res, - raft::make_device_vector_view(d_rev_graph_count_ptr, graph_size), - uint32_t(0)); + raft::matrix::fill(res, d_rev_graph, IdxT(-1)); + raft::matrix::fill(res, d_rev_graph_count, uint32_t(0)); - const uint32_t batch_size = - std::min(static_cast(graph_size), static_cast(256 * 1024)); - const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; + if (is_ptr_host_accessible(output_graph_ptr)) { + auto d_dest_nodes = + raft::make_device_mdarray(res, raft::make_extents(graph_size)); - batched_device_view_from_host d_output_graph( - res, - raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), - /*batch_size*/ batch_size, - /*host_writeback*/ false, - /*initialize*/ true); - - for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + for (uint64_t k = 0; k < output_graph_degree; k++) { + RAFT_CUDA_TRY(cudaMemcpy2DAsync(d_dest_nodes.data_handle(), + sizeof(IdxT), + output_graph_ptr + k, + output_graph_degree * sizeof(IdxT), + 1 * sizeof(IdxT), + graph_size, + cudaMemcpyHostToDevice, + raft::resource::get_cuda_stream(res))); + + dim3 threads(256, 1, 1); + dim3 blocks(1024, 1, 1); + kern_make_rev_graph<<>>( + d_dest_nodes.data_handle(), + d_rev_graph.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, + output_graph_degree); + RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); + } + } else { + // output graph is fully device accessible, so we need no copy to device dim3 threads(256, 1, 1); dim3 
blocks(1024, 1, 1); - auto output_view = d_output_graph.next_view(); - - kern_rev_graph_batched<<>>( - output_view.data_handle(), - d_rev_graph_ptr, - d_rev_graph_count_ptr, - static_cast(graph_size), - static_cast(output_graph_degree), - static_cast(batch_size), - static_cast(i_batch)); + for (uint64_t k = 0; k < output_graph_degree; k++) { + kern_make_rev_graph_k<<>>( + output_graph_ptr, + d_rev_graph.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, + output_graph_degree, + k); + } } raft::resource::sync_stream(res); RAFT_LOG_DEBUG("\n"); + + const double time_make_end = cur_time(); + RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", + (time_make_end - time_make_start) * 1000.0); } -} // namespace template void optimize(raft::resources const& res, InputMatrixView knn_graph, From 68f78839a5437d48f54d876b484efded20e8448d Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 13 Mar 2026 20:00:55 +0000 Subject: [PATCH 21/22] fix signedness --- cpp/src/neighbors/detail/cagra/utils.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 44d87d2993..7dae487863 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -561,7 +561,7 @@ class batched_device_view_from_host { // if data is on device only this is almost a noop, just prepping the pointers - RAFT_EXPECTS(offset_ <= host_view_.extent(0), "Offset out of bounds"); + RAFT_EXPECTS(static_cast(offset_) <= host_view_.extent(0), "Offset out of bounds"); bool next_batch_exists = offset_ < static_cast(host_view_.extent(0)); From add206a7697aaf019543a43e39f763858992c5a2 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 13 Mar 2026 22:51:05 +0000 Subject: [PATCH 22/22] stupid me trusting cursor to fix this --- cpp/src/neighbors/detail/cagra/utils.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 7dae487863..79d1ed1cae 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -561,7 +561,7 @@ class batched_device_view_from_host { // if data is on device only this is almost a noop, just prepping the pointers - RAFT_EXPECTS(static_cast(offset_) <= host_view_.extent(0), "Offset out of bounds"); + RAFT_EXPECTS(static_cast(offset_) <= host_view_.extent(0), "Offset out of bounds"); bool next_batch_exists = offset_ < static_cast(host_view_.extent(0));