diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index 9d3198eb74..34804ea9a9 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -833,27 +833,26 @@ inline std::pair optimize_workspace_size(size_t n_rows, // Prune stage memory // We neglect 8 bytes (both on host and device) for stats - size_t prune_host = n_rows * intermediate_degree * sizeof(uint8_t); // detour count + size_t batch_size = std::min(static_cast(256 * 1024), n_rows); - size_t prune_dev = n_rows * intermediate_degree * 1; // detour count (uint8_t) - prune_dev += n_rows * sizeof(uint32_t); // d_num_detour_edges - prune_dev += n_rows * intermediate_degree * index_size; // d_input_graph + size_t prune_dev = batch_size * intermediate_degree * 1; // detour count (uint8_t) + prune_dev += batch_size * sizeof(uint32_t); // d_num_detour_edges + prune_dev += n_rows * intermediate_degree * index_size; // d_input_graph // Reverse graph stage memory - size_t rev_host = n_rows * graph_degree * index_size; // rev_graph - rev_host += n_rows * sizeof(uint32_t); // rev_graph_count - rev_host += n_rows * index_size; // dest_nodes - size_t rev_dev = n_rows * graph_degree * index_size; // d_rev_graph rev_dev += n_rows * sizeof(uint32_t); // d_rev_graph_count rev_dev += n_rows * sizeof(uint32_t); // d_dest_nodes - // Memory for merging graphs (host only) + // Memory for merging graphs (host only optional) size_t combine_host = n_rows * sizeof(uint32_t) + graph_degree * sizeof(uint32_t); // in_edge_count + hist - size_t total_host = mst_host + std::max({prune_host, rev_host, combine_host}); - size_t total_dev = std::max(prune_dev, rev_dev); + // additional memory for combine stage on device + size_t combine_dev = n_rows * graph_degree * index_size; // d_output_graph + + size_t total_host = mst_host + combine_host; + size_t total_dev = std::max(prune_dev, rev_dev + combine_dev); return std::make_pair(total_host, total_dev); } diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index d94e279829..5d43da851b 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -22,17 +22,23 @@ #include #include +#include #include #include #include +#include +#include + #include #include #include #include +namespace cg = cooperative_groups; + namespace cuvs::neighbors::cagra::detail::graph { // unnamed namespace to avoid multiple definition error @@ -165,42 +171,100 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, } } -template -__global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph_degree] - const uint32_t graph_size, - const uint32_t graph_degree, - const uint32_t degree, - const uint32_t batch_size, - const uint32_t batch_id, - uint8_t* const detour_count, // [graph_chunk_size, graph_degree] - uint32_t* const num_no_detour_edges, // [graph_size] - uint64_t* const stats) +template +__global__ void kern_make_rev_graph(const IdxT* const dest_nodes, // [graph_size] + IdxT* const rev_graph, // [size, degree] + uint32_t* const rev_graph_count, // [graph_size] + const uint32_t graph_size, + const uint32_t degree) { - __shared__ uint32_t smem_num_detour[MAX_DEGREE]; + const uint32_t tid = threadIdx.x + (blockDim.x * blockIdx.x); + const uint32_t tnum = blockDim.x * gridDim.x; + + for (uint32_t src_id = tid; src_id < graph_size; src_id += tnum) { + const IdxT dest_id = dest_nodes[src_id]; + if (dest_id >= graph_size) continue; + + const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); + if (pos < degree) { rev_graph[pos + ((uint64_t)degree * dest_id)] = src_id; } + } +} + +template +__global__ void kern_make_rev_graph_k(const IdxT* const output_graph, // [graph_size, degree] + IdxT* const rev_graph, // [graph_size, degree] + uint32_t* const rev_graph_count, // [graph_size] + const uint32_t graph_size, + const uint32_t degree, + uint64_t k) +{ + const uint64_t tid = threadIdx.x + (blockDim.x * blockIdx.x); + const uint64_t tnum = blockDim.x * gridDim.x; + + for (uint64_t src_id = tid; src_id < graph_size; src_id += tnum) { + IdxT dest_id = output_graph[k + (degree * src_id)]; + if (dest_id >= graph_size) continue; + + const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); + if (pos < degree) { rev_graph[(degree * dest_id) + pos] = static_cast(src_id); } + } +} + +template +__global__ void kern_fused_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph_degree] + IdxT* const output_graph_ptr, // [batch_size, output_graph_degree] + const uint32_t graph_size, + const uint32_t knn_graph_degree, + const uint32_t output_graph_degree, + const uint32_t batch_size, + const uint32_t batch_id, + uint32_t* const d_invalid_neighbor_list, + uint64_t* const stats) +{ + extern __shared__ unsigned char smem_buf[]; + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + + const uint32_t wid = threadIdx.x / raft::WarpSize; + const uint32_t lane_id = threadIdx.x % raft::WarpSize; + + IdxT* const smem_indices = + reinterpret_cast(smem_buf + wid * knn_graph_degree * sizeof(IdxT)); + uint32_t* const smem_num_detour = reinterpret_cast( + smem_buf + wid * knn_graph_degree * sizeof(IdxT) + num_warps * knn_graph_degree * sizeof(IdxT)); + uint64_t* const num_retain = stats; uint64_t* const num_full = stats + 1; - const uint64_t iA = blockIdx.x + (batch_size * batch_id); - if (iA >= graph_size) { return; } - for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { + const uint32_t maxval16 = 0x0000ffff; + + const uint64_t nid_batch = blockIdx.x * num_warps + wid; + const uint64_t nid = nid_batch + (batch_size * batch_id); + + if (nid >= graph_size) { return; } + + // Load this node's neighbor row into shared memory to reduce global reads + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { smem_num_detour[k] = 0; - if (knn_graph[k + ((uint64_t)graph_degree * iA)] == iA) { + smem_indices[k] = knn_graph[k + ((uint64_t)knn_graph_degree * nid)]; + if (smem_indices[k] == nid) { // Lower the priority of self-edge - smem_num_detour[k] = graph_degree; + smem_num_detour[k] = knn_graph_degree; } } - __syncthreads(); + __syncwarp(); // count number of detours (A->D->B) - for (uint32_t kAD = 0; kAD < graph_degree - 1; kAD++) { - const uint64_t iD = knn_graph[kAD + (graph_degree * iA)]; + for (uint32_t kAD = 0; kAD < knn_graph_degree - 1; kAD++) { + const uint64_t iD = smem_indices[kAD]; if (iD >= graph_size) { continue; } - for (uint32_t kDB = threadIdx.x; kDB < graph_degree; kDB += blockDim.x) { - const uint64_t iB_candidate = knn_graph[kDB + ((uint64_t)graph_degree * iD)]; - for (uint32_t kAB = kAD + 1; kAB < graph_degree; kAB++) { + for (uint32_t kDB = lane_id; kDB < knn_graph_degree; kDB += raft::WarpSize) { + const uint64_t iB_candidate = knn_graph[kDB + ((uint64_t)knn_graph_degree * iD)]; + for (uint32_t kAB = kAD + 1; kAB < knn_graph_degree; kAB++) { // if ( kDB < kAB ) { - const uint64_t iB = knn_graph[kAB + (graph_degree * iA)]; + const uint64_t iB = smem_indices[kAB]; if (iB == iB_candidate) { atomicAdd(smem_num_detour + kAB, 1); break; @@ -208,44 +272,174 @@ __global__ void kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, g } } } - __syncthreads(); + __syncwarp(); } uint32_t num_edges_no_detour = 0; - for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { - detour_count[k + (graph_degree * iA)] = min(smem_num_detour[k], (uint32_t)255); + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { + smem_num_detour[k] = min(smem_num_detour[k], maxval16); if (smem_num_detour[k] == 0) { num_edges_no_detour++; } + if (smem_indices[k] >= graph_size) { smem_num_detour[k] = maxval16; } } - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 1); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 2); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 4); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 8); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 16); - num_edges_no_detour = min(num_edges_no_detour, degree); - if (threadIdx.x == 0) { - num_no_detour_edges[iA] = num_edges_no_detour; + __syncwarp(); + + num_edges_no_detour = cg::reduce(warp, num_edges_no_detour, cg::plus()); + num_edges_no_detour = min(num_edges_no_detour, output_graph_degree); + + if (lane_id == 0) { atomicAdd((unsigned long long int*)num_retain, (unsigned long long int)num_edges_no_detour); - if (num_edges_no_detour >= degree) { atomicAdd((unsigned long long int*)num_full, 1); } + if (num_edges_no_detour >= output_graph_degree) { + atomicAdd((unsigned long long int*)num_full, 1); + } + } + + for (uint32_t i = 0; i < output_graph_degree; i++) { + uint32_t local_min = maxval16; + uint32_t local_idx = maxval16; + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { + if (smem_num_detour[k] < local_min) { + local_min = smem_num_detour[k]; + local_idx = k; + } + } + + uint32_t local_min_with_tag = (local_min << 16) | ((uint32_t)local_idx); + uint32_t warp_min_with_tag = cg::reduce(warp, local_min_with_tag, cg::less()); + uint32_t warp_min_count = warp_min_with_tag >> 16; + uint32_t warp_local_idx = warp_min_with_tag & 0xffff; + + if (warp_min_count == maxval16 || warp_local_idx == maxval16) { + if (lane_id == 0) { atomicExch(d_invalid_neighbor_list, 1u); } + break; + } + + IdxT selected_node = smem_indices[warp_local_idx]; + + for (uint32_t k = lane_id; k < knn_graph_degree; k += raft::WarpSize) { + if (smem_indices[k] == selected_node) { smem_num_detour[k] = maxval16; } + } + __syncwarp(); + + if (lane_id == 0) { output_graph_ptr[nid_batch * output_graph_degree + i] = selected_node; } } } -template -__global__ void kern_make_rev_graph(const IdxT* const dest_nodes, // [graph_size] - IdxT* const rev_graph, // [size, degree] - uint32_t* const rev_graph_count, // [graph_size] - const uint32_t graph_size, - const uint32_t degree) +// Helper functions for merging the graph +template +__device__ unsigned int warp_pos_in_array(T val, const T* array, uint64_t num) { - const uint32_t tid = threadIdx.x + (blockDim.x * blockIdx.x); - const uint32_t tnum = blockDim.x * gridDim.x; + unsigned int ret = num; + const uint32_t lane_id = threadIdx.x % 32; + for (uint64_t i = lane_id; i < num; i += 32) { + if (val == array[i]) { + ret = i; + break; + } + } - for (uint32_t src_id = tid; src_id < graph_size; src_id += tnum) { - const IdxT dest_id = dest_nodes[src_id]; - if (dest_id >= graph_size) continue; + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + ret = cg::reduce(warp, ret, cg::less()); + return ret; +} - const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); - if (pos < degree) { rev_graph[pos + ((uint64_t)degree * dest_id)] = src_id; } +template +__device__ void thread_shift_array(T* array, uint64_t num) +{ + for (uint64_t i = num; i > 0; i--) { + array[i] = array[i - 1]; + } +} + +template +__global__ void kern_merge_graph(IdxT* output_graph, // [batch_size, output_graph_degree] + const IdxT* const rev_graph, // [graph_size, output_graph_degree] + uint32_t* const rev_graph_count, // [graph_size] + const uint32_t graph_size, + const uint32_t output_graph_degree, + const IdxT* const mst_graph, // [batch_size, output_graph_degree] + const uint32_t mst_graph_degree, + const uint32_t* const mst_graph_num_edges_ptr, // [batch_size] + const uint32_t batch_size, + const uint32_t batch_id, + bool guarantee_connectivity, + bool* check_num_protected_edges) +{ + extern __shared__ unsigned char smem_buf[]; + + const uint32_t wid = threadIdx.x / raft::WarpSize; + const uint32_t lane_id = threadIdx.x % raft::WarpSize; + + IdxT* smem_sorted_output_graph = + reinterpret_cast(smem_buf + wid * output_graph_degree * sizeof(IdxT)); + + const uint64_t nid_batch = blockIdx.x * num_warps + wid; + const uint64_t nid = nid_batch + (batch_size * batch_id); + + if (nid >= graph_size) { return; } + + const auto mst_graph_num_edges = guarantee_connectivity ? mst_graph_num_edges_ptr[nid_batch] : 0; + // If guarantee_connectivity == true, use a temporal list to merge the + // neighbor lists of the graphs. + if (guarantee_connectivity) { + for (uint32_t i = lane_id; i < mst_graph_degree; i += raft::WarpSize) { + smem_sorted_output_graph[i] = mst_graph[nid_batch * mst_graph_degree + i]; + } + __syncwarp(); + for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; + (pruned_j < output_graph_degree) && (output_j < output_graph_degree); + pruned_j++) { + const auto v = output_graph[output_graph_degree * nid_batch + pruned_j]; + unsigned int dup = 0; + for (uint32_t m = lane_id; m < output_j; m += raft::WarpSize) { + if (v == smem_sorted_output_graph[m]) { + dup = 1; + break; + } + } + + unsigned int warp_dup = __ballot_sync(0xffffffff, dup); + if (warp_dup == 0) { + if (lane_id == 0) smem_sorted_output_graph[output_j] = v; + output_j++; + } + __syncwarp(); + } + } + + else { + for (uint32_t i = lane_id; i < output_graph_degree; i += raft::WarpSize) { + smem_sorted_output_graph[i] = output_graph[output_graph_degree * nid_batch + i]; + } + __syncwarp(); + } + + const auto num_protected_edges = max(mst_graph_num_edges, output_graph_degree / 2); + + if (num_protected_edges > output_graph_degree) { check_num_protected_edges[0] = false; } + if (num_protected_edges == output_graph_degree) { return; } + + auto kr = min(rev_graph_count[nid], output_graph_degree); + + while (kr) { + kr -= 1; + if (rev_graph[kr + (output_graph_degree * nid)] < graph_size) { + uint64_t pos = warp_pos_in_array( + rev_graph[kr + (output_graph_degree * nid)], smem_sorted_output_graph, output_graph_degree); + if (pos < num_protected_edges) { continue; } + uint64_t num_shift = pos - num_protected_edges; + if (pos >= output_graph_degree) { num_shift = output_graph_degree - num_protected_edges - 1; } + if (lane_id == 0) { + thread_shift_array(smem_sorted_output_graph + num_protected_edges, num_shift); + smem_sorted_output_graph[num_protected_edges] = rev_graph[kr + (output_graph_degree * nid)]; + } + __syncwarp(); + } + } + + for (uint32_t i = lane_id; i < output_graph_degree; i += raft::WarpSize) { + output_graph[(output_graph_degree * nid_batch) + i] = smem_sorted_output_graph[i]; } } @@ -482,23 +676,245 @@ __global__ void kern_mst_opt_postprocessing(IdxT* outgoing_num_edges, // [graph } } -template -uint64_t pos_in_array(T val, const T* array, uint64_t num) +template +void log_incoming_edges_histogram(const IdxT* output_graph_ptr, + uint64_t graph_size, + uint64_t output_graph_degree) { - for (uint64_t i = 0; i < num; i++) { - if (val == array[i]) { return i; } + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/check_edges"); + auto in_edge_count = raft::make_host_vector(graph_size); + auto in_edge_count_ptr = in_edge_count.data_handle(); +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + in_edge_count_ptr[i] = 0; + } +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + for (uint64_t k = 0; k < output_graph_degree; k++) { + const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; + if (j >= graph_size) continue; +#pragma omp atomic + in_edge_count_ptr[j] += 1; + } + } + auto hist = raft::make_host_vector(output_graph_degree); + auto hist_ptr = hist.data_handle(); + for (uint64_t k = 0; k < output_graph_degree; k++) { + hist_ptr[k] = 0; + } +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + uint32_t count = in_edge_count_ptr[i]; + if (count >= output_graph_degree) continue; +#pragma omp atomic + hist_ptr[count] += 1; + } + RAFT_LOG_DEBUG("# Histogram for number of incoming edges\n"); + uint32_t sum_hist = 0; + for (uint64_t k = 0; k < output_graph_degree; k++) { + sum_hist += hist_ptr[k]; + RAFT_LOG_DEBUG("# %3lu, %8u, %lf, (%8u, %lf)\n", + k, + hist_ptr[k], + (double)hist_ptr[k] / graph_size, + sum_hist, + (double)sum_hist / graph_size); } - return num; } -template -void shift_array(T* array, uint64_t num) +template +void check_duplicates_and_out_of_range(const IdxT* output_graph_ptr, + uint64_t graph_size, + uint64_t output_graph_degree) { - for (uint64_t i = num; i > 0; i--) { - array[i] = array[i - 1]; + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/check_duplicates"); + uint64_t num_dup = 0; + uint64_t num_oor = 0; +#pragma omp parallel for reduction(+ : num_dup) reduction(+ : num_oor) + for (uint64_t i = 0; i < graph_size; i++) { + auto my_out_graph = output_graph_ptr + (output_graph_degree * i); + for (uint32_t j = 0; j < output_graph_degree; j++) { + const auto neighbor_a = my_out_graph[j]; + + if (neighbor_a > graph_size) { + num_oor++; + continue; + } + + for (uint32_t k = j + 1; k < output_graph_degree; k++) { + const auto neighbor_b = my_out_graph[k]; + if (neighbor_a == neighbor_b) { + num_dup++; + break; + } + } + } } + RAFT_EXPECTS( + num_dup == 0, "%lu duplicated node(s) are found in the generated CAGRA graph", num_dup); + RAFT_EXPECTS( + num_oor == 0, "%lu out-of-range index node(s) are found in the generated CAGRA graph", num_oor); +} + +template +void merge_graph_gpu(raft::resources const& res, + IdxT* output_graph_ptr, + const IdxT* d_rev_graph_ptr, + uint32_t* d_rev_graph_count_ptr, + IdxT* mst_graph_ptr, + uint32_t* mst_graph_num_edges_ptr, + uint64_t graph_size, + uint64_t output_graph_degree, + bool guarantee_connectivity) +{ + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/combine"); + + auto default_ws_mr = raft::resource::get_workspace_resource(res); + const double merge_graph_start = cur_time(); + + auto d_check_num_protected_edges = raft::make_device_scalar(res, true); + auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); + + uint32_t batch_size = + std::min(static_cast(graph_size), static_cast(256 * 1024)); + const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; + + batched_device_view_from_host d_output_graph( + res, + raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), + /*batch_size*/ batch_size, + /*host_writeback*/ true, + /*initialize*/ true); + + batched_device_view_from_host d_mst_graph( + res, + raft::make_host_matrix_view( + mst_graph_ptr, guarantee_connectivity ? graph_size : 0, output_graph_degree), + /*batch_size*/ batch_size, + /*host_writeback*/ false, + /*initialize*/ true); + + batched_device_view_from_host d_mst_graph_num_edges( + res, + raft::make_host_matrix_view( + mst_graph_ptr, guarantee_connectivity ? graph_size : 0, output_graph_degree), + /*batch_size*/ batch_size, + /*host_writeback*/ false, + /*initialize*/ true); + + const uint32_t num_warps = 4; + const dim3 threads_merge(raft::WarpSize * num_warps, 1, 1); + const dim3 blocks_merge(raft::ceildiv(batch_size, num_warps), 1, 1); + const size_t merge_smem_size = num_warps * output_graph_degree * sizeof(IdxT); + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + auto mst_graph_view = d_mst_graph.next_view(); + auto mst_graph_num_edges_view = d_mst_graph_num_edges.next_view(); + auto output_view = d_output_graph.next_view(); + kern_merge_graph + <<>>( + output_view.data_handle(), + d_rev_graph_ptr, + d_rev_graph_count_ptr, + static_cast(graph_size), + static_cast(output_graph_degree), + mst_graph_view.data_handle(), + static_cast(output_graph_degree), + mst_graph_num_edges_view.data_handle(), + batch_size, + i_batch, + guarantee_connectivity, + d_check_num_protected_edges.data_handle()); + } + + bool check_num_protected_edges = true; + raft::copy(&check_num_protected_edges, + d_check_num_protected_edges.data_handle(), + 1, + raft::resource::get_cuda_stream(res)); + + const auto merge_graph_end = cur_time(); + RAFT_EXPECTS(check_num_protected_edges, + "Failed to merge the MST, pruned, and reverse edge graphs. " + "Some nodes have too " + "many MST optimization edges."); + + RAFT_LOG_DEBUG("# Time for merging graphs: %.1lf ms", + (merge_graph_end - merge_graph_start) * 1000.0); +} + +template +void make_reverse_graph_gpu(raft::resources const& res, + IdxT* output_graph_ptr, + IdxT* d_rev_graph_ptr, + uint32_t* d_rev_graph_count_ptr, + uint64_t graph_size, + uint64_t output_graph_degree) +{ + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/reverse2"); + + auto d_rev_graph = + raft::make_device_vector_view(d_rev_graph_ptr, graph_size * output_graph_degree); + auto d_rev_graph_count = + raft::make_device_vector_view(d_rev_graph_count_ptr, graph_size); + + // + // Make reverse graph + // + const double time_make_start = cur_time(); + + raft::matrix::fill(res, d_rev_graph, IdxT(-1)); + raft::matrix::fill(res, d_rev_graph_count, uint32_t(0)); + + if (is_ptr_host_accessible(output_graph_ptr)) { + auto d_dest_nodes = + raft::make_device_mdarray(res, raft::make_extents(graph_size)); + + for (uint64_t k = 0; k < output_graph_degree; k++) { + RAFT_CUDA_TRY(cudaMemcpy2DAsync(d_dest_nodes.data_handle(), + sizeof(IdxT), + output_graph_ptr + k, + output_graph_degree * sizeof(IdxT), + 1 * sizeof(IdxT), + graph_size, + cudaMemcpyHostToDevice, + raft::resource::get_cuda_stream(res))); + + dim3 threads(256, 1, 1); + dim3 blocks(1024, 1, 1); + kern_make_rev_graph<<>>( + d_dest_nodes.data_handle(), + d_rev_graph.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, + output_graph_degree); + RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); + } + } else { + // output graph is fully device accessible, so we need no copy to device + dim3 threads(256, 1, 1); + dim3 blocks(1024, 1, 1); + for (uint64_t k = 0; k < output_graph_degree; k++) { + kern_make_rev_graph_k<<>>( + output_graph_ptr, + d_rev_graph.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, + output_graph_degree, + k); + } + } + + raft::resource::sync_stream(res); + RAFT_LOG_DEBUG("\n"); + + const double time_make_end = cur_time(); + RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", + (time_make_end - time_make_start) * 1000.0); } -} // namespace template +template void mst_optimization(raft::resources const& res, - raft::host_matrix_view input_graph, - raft::host_matrix_view output_graph, - raft::host_vector_view mst_graph_num_edges, + InputMatrixView input_graph, + OutputMatrixView output_graph, + VectorView mst_graph_num_edges, bool use_gpu = true) { if (use_gpu) { @@ -1092,71 +1508,130 @@ void mst_optimization(raft::resources const& res, RAFT_LOG_DEBUG("# MST optimization time: %.1lf sec", time_mst_opt_end - time_mst_opt_start); } -template -void count_2hop_detours(raft::host_matrix_view knn_graph, - raft::host_matrix_view detour_count) +// +// Prune unimportant edges based on 2-hop detour counts. +// +// The edge to be retained is determined without explicitly considering distance or angle. +// Suppose the edge is the k-th edge of some node-A to node-B (A->B). Among the edges +// originating at node-A, there are k-1 edges shorter than the edge A->B. Each of these +// k-1 edges are connected to a different k-1 nodes. Among these k-1 nodes, count the +// number of nodes with edges to node-B, which is the number of 2-hop detours for the +// edge A->B. Once the number of 2-hop detours has been counted for all edges, the +// specified number of edges are picked up for each node, starting with the edge with +// the lowest number of 2-hop detours. +// +template +void prune_graph_gpu(raft::resources const& res, + IdxT* knn_graph_ptr, + uint64_t graph_size, + uint64_t knn_graph_degree, + IdxT* output_graph_ptr, + uint64_t output_graph_degree) { - RAFT_EXPECTS(knn_graph.extent(0) == detour_count.extent(0), - "knn_graph and detour_count are expected to have the same number of rows"); - RAFT_EXPECTS(knn_graph.extent(1) == detour_count.extent(1), - "knn_graph and detour_count are expected to have the same number of cols"); - const uint64_t graph_size = knn_graph.extent(0); - const uint64_t graph_degree = knn_graph.extent(1); + raft::common::nvtx::range block_scope( + "cagra::graph::optimize/prune"); + auto default_ws_mr = raft::resource::get_workspace_resource(res); + + uint32_t batch_size = + std::min(static_cast(graph_size), static_cast(256 * 1024)); + const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; + + RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); + + const double prune_start = cur_time(); + + uint64_t num_keep __attribute__((unused)) = 0; + uint64_t num_full __attribute__((unused)) = 0; + auto dev_stats = raft::make_device_vector(res, 2); + auto host_stats = raft::make_host_vector(2); + raft::matrix::fill(res, dev_stats.view(), uint64_t(0)); + + batched_device_view_from_host d_input_graph( + res, + raft::make_host_matrix_view(knn_graph_ptr, graph_size, knn_graph_degree), + /*batch_size*/ graph_size, + /*host_writeback*/ false, + /*initialize*/ true); + auto input_view = d_input_graph.next_view(); + + batched_device_view_from_host d_output_graph( + res, + raft::make_host_matrix_view(output_graph_ptr, graph_size, output_graph_degree), + /*batch_size*/ batch_size, + /*host_writeback*/ true, + /*initialize*/ false); + + auto d_invalid_neighbor_list = raft::make_device_scalar(res, 0u); + + for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { + auto output_view = d_output_graph.next_view(); + const uint32_t num_warps = 4; + const dim3 threads_prune(raft::WarpSize * num_warps, 1, 1); + const dim3 blocks_prune(raft::ceildiv(batch_size, num_warps), 1, 1); + const size_t prune_smem_size = num_warps * knn_graph_degree * (sizeof(IdxT) + sizeof(uint32_t)); + kern_fused_prune + <<>>( + input_view.data_handle(), + output_view.data_handle(), + graph_size, + knn_graph_degree, + output_graph_degree, + batch_size, + i_batch, + d_invalid_neighbor_list.data_handle(), + dev_stats.data_handle()); -#pragma omp parallel for - for (IdxT iA = 0; iA < graph_size; iA++) { - // Create a list of nodes, iB_candidates, that can be reached in 2-hops from node A. - auto iB_candidates = - raft::make_host_vector((graph_degree - 1) * (graph_degree - 1)); - for (uint64_t kAC = 0; kAC < graph_degree - 1; kAC++) { - IdxT iC = knn_graph(iA, kAC); - for (uint64_t kCB = 0; kCB < graph_degree - 1; kCB++) { - IdxT iB_candidate; - if (iC == iA || iC >= graph_size) { - iB_candidate = graph_size; - } else { - iB_candidate = knn_graph(iC, kCB); - if (iB_candidate == iA || iB_candidate == iC) { iB_candidate = graph_size; } - } - uint64_t idx; - if (kAC < kCB) { - idx = (kCB * kCB) + kAC; - } else { - idx = (kAC * (kAC + 1)) + kCB; - } - iB_candidates(idx) = iB_candidate; - } - } - // Count how many 2-hop detours are on each edge of node A. - for (uint64_t kAB = 0; kAB < graph_degree; kAB++) { - constexpr uint32_t max_count = 255; - uint32_t count = 0; - IdxT iB = knn_graph(iA, kAB); - if (iB == iA) { - count = max_count; - } else { - for (uint64_t idx = 0; idx < kAB * kAB; idx++) { - if (iB_candidates(idx) == iB) { count += 1; } - } - } - detour_count(iA, kAB) = std::min(count, max_count); - } + raft::resource::sync_stream(res); + RAFT_LOG_DEBUG( + "# Pruning kNN Graph on GPUs (%.1lf %%)\r", + (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); } + raft::resource::sync_stream(res); + RAFT_LOG_DEBUG("\n"); + + uint32_t invalid_neighbor_list = 0; + raft::copy(&invalid_neighbor_list, + d_invalid_neighbor_list.data_handle(), + 1, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + RAFT_EXPECTS( + invalid_neighbor_list == 0, + "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " + "many invalid or duplicated neighbor nodes. This error can occur, for example, if too many " + "overflows occur during the norm computation between the dataset vectors."); + + raft::copy(res, host_stats.view(), raft::make_const_mdspan(dev_stats.view())); + num_keep = host_stats.data_handle()[0]; + num_full = host_stats.data_handle()[1]; + + const double prune_end = cur_time(); + RAFT_LOG_DEBUG( + "# Time for pruning on GPU: %.1lf sec, " + "avg_no_detour_edges_per_node: %.2lf/%u, " + "nodes_with_no_detour_at_all_edges: %.1lf%%", + prune_end - prune_start, + (double)num_keep / graph_size, + output_graph_degree, + (double)num_full / graph_size * 100); } -template , raft::memory_type::host>> -void optimize( - raft::resources const& res, - raft::mdspan, raft::row_major, g_accessor> knn_graph, - raft::host_matrix_view new_graph, - const bool guarantee_connectivity = true, - const bool use_gpu = true) +} // namespace + +template +void optimize(raft::resources const& res, + InputMatrixView knn_graph, + OutputMatrixView new_graph, + const bool guarantee_connectivity = true, + const bool use_gpu = true) { RAFT_LOG_DEBUG( "# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1)); + + // large temporary memory for large arrays, e.g. everything >= O(graph_size) auto large_tmp_mr = raft::resource::get_large_workspace_resource(res); + // temporary memory for small arrays, e.g. everything <= O(batchsize * graph_degree) + auto default_ws_mr = raft::resource::get_workspace_resource(res); RAFT_EXPECTS(knn_graph.extent(0) == new_graph.extent(0), "Each input array is expected to have the same number of rows"); @@ -1167,25 +1642,27 @@ void optimize( const uint64_t output_graph_degree = new_graph.extent(1); const uint64_t graph_size = new_graph.extent(0); // auto input_graph_ptr = knn_graph.data_handle(); - auto output_graph_ptr = new_graph.data_handle(); raft::common::nvtx::range fun_scope( "cagra::graph::optimize(%zu, %zu, %u)", graph_size, knn_graph_degree, output_graph_degree); // MST optimization - auto mst_graph = raft::make_host_matrix(0, 0); - auto mst_graph_num_edges = raft::make_host_vector(graph_size); - auto mst_graph_num_edges_ptr = mst_graph_num_edges.data_handle(); -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - mst_graph_num_edges_ptr[i] = 0; - } + // currently, only using GPU path for MST optimization + auto mst_graph = raft::make_host_matrix(0, 0); + auto mst_graph_num_edges = raft::make_host_vector(0); + if (guarantee_connectivity) { + auto mst_graph_num_edges = raft::make_host_vector(graph_size); + auto mst_graph_num_edges_ptr = mst_graph_num_edges.data_handle(); +#pragma omp parallel for + for (uint64_t i = 0; i < graph_size; i++) { + mst_graph_num_edges_ptr[i] = 0; + } raft::common::nvtx::range block_scope( "cagra::graph::optimize/check_connectivity"); mst_graph = raft::make_host_matrix(graph_size, output_graph_degree); RAFT_LOG_INFO("MST optimization is used to guarantee graph connectivity."); - mst_optimization(res, knn_graph, mst_graph.view(), mst_graph_num_edges.view(), use_gpu); + mst_optimization(res, knn_graph, mst_graph.view(), mst_graph_num_edges.view(), use_gpu); for (uint64_t i = 0; i < graph_size; i++) { if (i < 8 || i >= graph_size - 8) { @@ -1194,437 +1671,60 @@ void optimize( } } + // prune graph -- will always use GPU path { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune"); - const double time_prune_start = cur_time(); - - // - // Prune unimportant edges. - // - // The edge to be retained is determined without explicitly considering - // distance or angle. Suppose the edge is the k-th edge of some node-A to - // node-B (A->B). Among the edges originating at node-A, there are k-1 edges - // shorter than the edge A->B. Each of these k-1 edges are connected to a - // different k-1 nodes. Among these k-1 nodes, count the number of nodes with - // edges to node-B, which is the number of 2-hop detours for the edge A->B. - // Once the number of 2-hop detours has been counted for all edges, the - // specified number of edges are picked up for each node, starting with the - // edge with the lowest number of 2-hop detours. - // - auto detour_count = raft::make_host_matrix(graph_size, knn_graph_degree); - - // - // If the available device memory is insufficient, do not use the GPU to count - // the number of 2-hop detours, but use the CPU. - // - bool _use_gpu = use_gpu; - if (_use_gpu) { - try { - auto d_detour_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); - auto d_num_no_detour_edges = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size)); - auto d_input_graph = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); - } catch (std::bad_alloc& e) { - RAFT_LOG_DEBUG("Insufficient memory for 2-hop node counting on GPU"); - _use_gpu = false; - } catch (raft::logic_error& e) { - RAFT_LOG_DEBUG("Insufficient memory for 2-hop node counting on GPU (logic error)"); - _use_gpu = false; - } - } - if (_use_gpu) { - // Count 2-hop detours on GPU - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune/2-hop-counting-by-GPU"); - const double time_2hop_count_start = cur_time(); - - uint64_t num_keep __attribute__((unused)) = 0; - uint64_t num_full __attribute__((unused)) = 0; - auto d_detour_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size, knn_graph_degree)); - - raft::matrix::fill(res, d_detour_count.view(), uint8_t(0xff)); - - auto d_num_no_detour_edges = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size)); - raft::matrix::fill(res, d_num_no_detour_edges.view(), uint32_t(0)); - - auto dev_stats = raft::make_device_vector(res, 2); - auto host_stats = raft::make_host_vector(2); - - RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); - - // Copy knn_graph over to device if necessary - device_matrix_view_from_host d_input_graph( - res, - raft::make_host_matrix_view( - knn_graph.data_handle(), graph_size, knn_graph_degree)); - - constexpr int MAX_DEGREE = 1024; - if (knn_graph_degree > MAX_DEGREE) { - RAFT_FAIL( - "The degree of input knn graph is too large (%zu). " - "It must be equal to or smaller than %d.", - knn_graph_degree, - MAX_DEGREE); - } - const uint32_t batch_size = - std::min(static_cast(graph_size), static_cast(256 * 1024)); - const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - const dim3 threads_prune(32, 1, 1); - const dim3 blocks_prune(batch_size, 1, 1); - - raft::matrix::fill(res, dev_stats.view(), uint64_t(0)); - - for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - kern_prune - <<>>( - d_input_graph.data_handle(), - graph_size, - knn_graph_degree, - output_graph_degree, - batch_size, - i_batch, - d_detour_count.data_handle(), - d_num_no_detour_edges.data_handle(), - dev_stats.data_handle()); - raft::resource::sync_stream(res); - RAFT_LOG_DEBUG( - "# Pruning kNN Graph on GPUs (%.1lf %%)\r", - (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); - } - raft::resource::sync_stream(res); - RAFT_LOG_DEBUG("\n"); - - raft::copy(res, detour_count.view(), raft::make_const_mdspan(d_detour_count.view())); - - raft::copy(res, host_stats.view(), raft::make_const_mdspan(dev_stats.view())); - num_keep = host_stats.data_handle()[0]; - num_full = host_stats.data_handle()[1]; - - const double time_2hop_count_end = cur_time(); - RAFT_LOG_DEBUG( - "# Time for 2-hop detour counting on GPU: %.1lf sec, " - "avg_no_detour_edges_per_node: %.2lf/%u, " - "nodes_with_no_detour_at_all_edges: %.1lf%%", - time_2hop_count_end - time_2hop_count_start, - (double)num_keep / graph_size, - output_graph_degree, - (double)num_full / graph_size * 100); - } else { - // Count 2-hop detours on CPU - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/prune/2-hop-counting-by-CPU"); - const double time_2hop_count_start = cur_time(); - - count_2hop_detours(knn_graph, detour_count.view()); - - const double time_2hop_count_end = cur_time(); - RAFT_LOG_DEBUG("# Time for 2-hop detour counting on CPU: %.1lf sec", - time_2hop_count_end - time_2hop_count_start); - } - - // Create pruned kNN graph - bool invalid_neighbor_list = false; -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - // Find the `output_graph_degree` smallest detourable count nodes by checking the detourable - // count of the neighbors while increasing the target detourable count from zero. - uint64_t pk = 0; - uint32_t num_detour = 0; - for (uint32_t l = 0; l < knn_graph_degree && pk < output_graph_degree; l++) { - uint32_t next_num_detour = std::numeric_limits::max(); - for (uint64_t k = 0; k < knn_graph_degree; k++) { - const auto num_detour_k = detour_count(i, k); - // Find the detourable count to check in the next iteration - if (num_detour_k > num_detour) { - next_num_detour = std::min(static_cast(num_detour_k), next_num_detour); - } - - // Store the neighbor index if its detourable count is equal to `num_detour`. - if (num_detour_k != num_detour) { continue; } - - // Check duplication and append - const auto candidate_node = knn_graph(i, k); - bool dup = false; - for (uint32_t dk = 0; dk < pk; dk++) { - if (candidate_node == output_graph_ptr[i * output_graph_degree + dk]) { - dup = true; - break; - } - } - if (!dup && candidate_node < graph_size) { - output_graph_ptr[i * output_graph_degree + pk] = candidate_node; - pk += 1; - } - if (pk >= output_graph_degree) break; - } - if (pk >= output_graph_degree) break; - - if (next_num_detour == std::numeric_limits::max()) { - // There are no valid edges enough in the initial kNN graph. Break the loop here and catch - // the error at the next validation (pk != output_graph_degree). - break; - } - num_detour = next_num_detour; - } - if (pk != output_graph_degree) { - RAFT_LOG_DEBUG( - "Couldn't find the output_graph_degree (%lu) smallest detourable count nodes for " - "node %lu in the rank-based node reranking process", - output_graph_degree, - i); - invalid_neighbor_list = true; - } - } - RAFT_EXPECTS( - !invalid_neighbor_list, - "Could not generate an intermediate CAGRA graph because the initial kNN graph contains too " - "many invalid or duplicated neighbor nodes. This error can occur, for example, if too many " - "overflows occur during the norm computation between the dataset vectors."); - - const double time_prune_end = cur_time(); - RAFT_LOG_DEBUG("# Pruning time: %.1lf ms", (time_prune_end - time_prune_start) * 1000.0); + prune_graph_gpu(res, + knn_graph.data_handle(), + graph_size, + knn_graph_degree, + new_graph.data_handle(), + output_graph_degree); } - auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); - auto rev_graph_count = raft::make_host_vector(graph_size); + // reverse graph creation will always use the GPU + auto d_rev_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, output_graph_degree)); - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/reverse"); - // - // Make reverse graph - // - const double time_make_start = cur_time(); - - device_matrix_view_from_host d_rev_graph(res, rev_graph.view()); - raft::matrix::fill(res, - raft::make_device_vector_view( - d_rev_graph.data_handle(), graph_size * output_graph_degree), - IdxT(-1)); - - auto d_rev_graph_count = raft::make_device_mdarray( - res, large_tmp_mr, raft::make_extents(graph_size)); - raft::matrix::fill(res, d_rev_graph_count.view(), uint32_t(0)); - - auto dest_nodes = raft::make_host_vector(graph_size); - auto d_dest_nodes = - raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); + // This should use the default workspace resource for random access / atomics + auto d_rev_graph_count = raft::make_device_mdarray( + res, default_ws_mr, raft::make_extents(graph_size)); - for (uint64_t k = 0; k < output_graph_degree; k++) { -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - // dest_nodes.data_handle()[i] = output_graph_ptr[k + (output_graph_degree * i)]; - dest_nodes(i) = output_graph_ptr[k + (output_graph_degree * i)]; - } - raft::resource::sync_stream(res); + const double time_make_start = cur_time(); - raft::copy(res, d_dest_nodes.view(), raft::make_const_mdspan(dest_nodes.view())); + make_reverse_graph_gpu(res, + new_graph.data_handle(), + d_rev_graph.data_handle(), + d_rev_graph_count.data_handle(), + graph_size, + output_graph_degree); - dim3 threads(256, 1, 1); - dim3 blocks(1024, 1, 1); - kern_make_rev_graph<<>>( - d_dest_nodes.data_handle(), - d_rev_graph.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - output_graph_degree); - RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); - } - - raft::resource::sync_stream(res); - RAFT_LOG_DEBUG("\n"); - - if (d_rev_graph.allocated_memory()) { - raft::copy(res, rev_graph.view(), raft::make_const_mdspan(d_rev_graph.view())); - } - raft::copy(res, rev_graph_count.view(), raft::make_const_mdspan(d_rev_graph_count.view())); - - const double time_make_end = cur_time(); - RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", - (time_make_end - time_make_start) * 1000.0); - } + const double time_make_end = cur_time(); + RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf ms", + (time_make_end - time_make_start) * 1000.0); + // merge graph -- will always use GPU path { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/combine"); - // - // Create search graphs from MST and pruned and reverse graphs - // - const double time_replace_start = cur_time(); - - bool check_num_protected_edges = true; -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - auto my_rev_graph = rev_graph.data_handle() + (output_graph_degree * i); - auto my_out_graph = output_graph_ptr + (output_graph_degree * i); - - // If guarantee_connectivity == true, use a temporal list to merge the neighbor lists of the - // graphs. - std::vector temp_output_neighbor_list; - if (guarantee_connectivity) { - temp_output_neighbor_list.resize(output_graph_degree); - my_out_graph = temp_output_neighbor_list.data(); - const auto mst_graph_num_edges = mst_graph_num_edges_ptr[i]; - - // Set MST graph edges - for (uint32_t j = 0; j < mst_graph_num_edges; j++) { - my_out_graph[j] = mst_graph(i, j); - } - - // Set pruned graph edges - for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges; - (pruned_j < output_graph_degree) && (output_j < output_graph_degree); - pruned_j++) { - const auto v = output_graph_ptr[output_graph_degree * i + pruned_j]; - - // duplication check - bool dup = false; - for (uint32_t m = 0; m < output_j; m++) { - if (v == my_out_graph[m]) { - dup = true; - break; - } - } - - if (!dup) { - my_out_graph[output_j] = v; - output_j++; - } - } - } - - const auto num_protected_edges = - std::max(mst_graph_num_edges_ptr[i], output_graph_degree / 2); - if (num_protected_edges > output_graph_degree) { check_num_protected_edges = false; } - if (num_protected_edges == output_graph_degree) continue; - - // Replace some edges of the output graph with edges of the reverse graph. - auto kr = std::min(rev_graph_count.data_handle()[i], output_graph_degree); - while (kr) { - kr -= 1; - if (my_rev_graph[kr] < graph_size) { - uint64_t pos = pos_in_array(my_rev_graph[kr], my_out_graph, output_graph_degree); - if (pos < num_protected_edges) { continue; } - uint64_t num_shift = pos - num_protected_edges; - if (pos >= output_graph_degree) { - num_shift = output_graph_degree - num_protected_edges - 1; - } - shift_array(my_out_graph + num_protected_edges, num_shift); - my_out_graph[num_protected_edges] = my_rev_graph[kr]; - } - } - - // If guarantee_connectivity == true, move the output neighbor list from the temporal list to - // the output list. If false, the copy is not needed because my_out_graph is a pointer to the - // output buffer. - if (guarantee_connectivity) { - for (uint32_t j = 0; j < output_graph_degree; j++) { - output_graph_ptr[(output_graph_degree * i) + j] = my_out_graph[j]; - } - } - } - RAFT_EXPECTS(check_num_protected_edges, - "Failed to merge the MST, pruned, and reverse edge graphs. Some nodes have too " - "many MST optimization edges."); - - const double time_replace_end = cur_time(); - RAFT_LOG_DEBUG("# Replacing edges time: %.1lf ms", - (time_replace_end - time_replace_start) * 1000.0); - - /* stats */ - uint64_t num_replaced_edges = 0; -#pragma omp parallel for reduction(+ : num_replaced_edges) - for (uint64_t i = 0; i < graph_size; i++) { - for (uint64_t k = 0; k < output_graph_degree; k++) { - const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; - const uint64_t pos = - pos_in_array(j, output_graph_ptr + (output_graph_degree * i), output_graph_degree); - if (pos == output_graph_degree) { num_replaced_edges += 1; } - } - } - RAFT_LOG_DEBUG("# Average number of replaced edges per node: %.2f", - (double)num_replaced_edges / graph_size); + merge_graph_gpu(res, + new_graph.data_handle(), + d_rev_graph.data_handle(), + d_rev_graph_count.data_handle(), + mst_graph.data_handle(), + mst_graph_num_edges.data_handle(), + graph_size, + output_graph_degree, + guarantee_connectivity); } - // Check number of incoming edges - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/check_edges"); - auto in_edge_count = raft::make_host_vector(graph_size); - auto in_edge_count_ptr = in_edge_count.data_handle(); -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - in_edge_count_ptr[i] = 0; - } -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - for (uint64_t k = 0; k < output_graph_degree; k++) { - const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; - if (j >= graph_size) continue; -#pragma omp atomic - in_edge_count_ptr[j] += 1; - } - } - auto hist = raft::make_host_vector(output_graph_degree); - auto hist_ptr = hist.data_handle(); - for (uint64_t k = 0; k < output_graph_degree; k++) { - hist_ptr[k] = 0; - } -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - uint32_t count = in_edge_count_ptr[i]; - if (count >= output_graph_degree) continue; -#pragma omp atomic - hist_ptr[count] += 1; - } - RAFT_LOG_DEBUG("# Histogram for number of incoming edges\n"); - uint32_t sum_hist = 0; - for (uint64_t k = 0; k < output_graph_degree; k++) { - sum_hist += hist_ptr[k]; - RAFT_LOG_DEBUG("# %3lu, %8u, %lf, (%8u, %lf)\n", - k, - hist_ptr[k], - (double)hist_ptr[k] / graph_size, - sum_hist, - (double)sum_hist / graph_size); - } - } + raft::resource::sync_stream(res); - // Check duplication and out-of-range indices - { - raft::common::nvtx::range block_scope( - "cagra::graph::optimize/check_duplicates"); - uint64_t num_dup = 0; - uint64_t num_oor = 0; -#pragma omp parallel for reduction(+ : num_dup) reduction(+ : num_oor) - for (uint64_t i = 0; i < graph_size; i++) { - auto my_out_graph = output_graph_ptr + (output_graph_degree * i); - for (uint32_t j = 0; j < output_graph_degree; j++) { - const auto neighbor_a = my_out_graph[j]; - - // Check oor - if (neighbor_a > graph_size) { - num_oor++; - continue; - } + if (is_ptr_host_accessible(new_graph.data_handle())) { + // following checks require host access + log_incoming_edges_histogram(new_graph.data_handle(), graph_size, output_graph_degree); - // Check duplication - for (uint32_t k = j + 1; k < output_graph_degree; k++) { - const auto neighbor_b = my_out_graph[k]; - if (neighbor_a == neighbor_b) { num_dup++; } - } - } - } - RAFT_EXPECTS( - num_dup == 0, "%lu duplicated node(s) are found in the generated CAGRA graph", num_dup); - RAFT_EXPECTS(num_oor == 0, - "%lu out-of-range index node(s) are found in the generated CAGRA graph", - num_oor); + check_duplicates_and_out_of_range( + new_graph.data_handle(), graph_size, output_graph_degree); + } else { + RAFT_LOG_DEBUG("Output graph is on GPU, skipping checks"); } } diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index b1dadb1e36..79d1ed1cae 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -9,9 +9,14 @@ #include #include #include +#include +#include +#include +#include +#include #include #include - +#include #include #include @@ -19,6 +24,7 @@ #include #include +#include #include namespace cuvs::neighbors::cagra::detail { @@ -154,6 +160,22 @@ struct gen_index_msb_1_mask { }; } // namespace utils +template +bool is_ptr_device_accessible(T* ptr) +{ + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); + return attr.devicePointer != nullptr; +} + +template +bool is_ptr_host_accessible(T* ptr) +{ + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, ptr)); + return attr.hostPointer != nullptr; +} + /** * Utility to sync memory from a host_matrix_view to a device_matrix_view * @@ -292,4 +314,395 @@ void copy_with_padding( raft::resource::get_cuda_stream(res))); } } + +/** + * Utility to create a batched device view from a host view + * + * This utility will create a batched device view from a host view and will handle the prefetch and + * writeback of the data Each batch can be referenced exactlyonce by calling the next_view() + * function + * + * Usage: + * ``` + * batched_device_view_from_host view(res, host_view, batch_size, host_writeback, + * initialize); while (view.next_view().extent(0) > 0) { auto device_view = view.next_view(); + * // use device_view + * } + * ``` + * + * The call to next_view() will + * * synchronize on all previous operations / increments batch_id_ + * * (optionally) write back the data of the previous batch to the host + * * (optionally) prefetch the data of the next batch + * * return the view of the current batch + * + * @tparam T The type of the data + * @tparam IdxT The type of the index + */ +template +class batched_device_view_from_host { + public: + enum class memory_strategy { + device_only, // data is on device only (no copy needed) + copy_device, // data is explicitly moved to/from device buffers + managed_only, // data is on managed memory (system managed) + }; + + /** + * Create a batched device view from a host view and will handle the prefetch and + * writeback of the data. Each batch can be referenced exactly once by calling the next_view() + * method. + * + * @param res The resources to use + * @param host_view The host view to create the batched device view from + * @param batch_size The batch size + * @param host_writeback Whether to write back the data to the host (only for host memory) + * (default: false) + * @param initialize Whether to initialize the data (only for managed memory) (default: true) + */ + batched_device_view_from_host(raft::resources const& res, + raft::host_matrix_view host_view, + uint64_t batch_size, + bool host_writeback = false, + bool initialize = true) + : res_(res), + host_view_(host_view), + batch_size_(batch_size), + offset_(0), + batch_id_(-2), + num_buffers_(2), + host_writeback_(host_writeback), + initialize_(initialize) + { + if (host_view.extent(0) == 0) { + mem_strategy_ = memory_strategy::device_only; + return; + } + + RAFT_EXPECTS(host_writeback_ || initialize_, + "At least one of host_writeback or initialize must be true"); + + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr_, host_view.data_handle())); + switch (attr_.type) { + case cudaMemoryTypeUnregistered: + case cudaMemoryTypeHost: + case cudaMemoryTypeManaged: mem_strategy_ = memory_strategy::copy_device; break; + case cudaMemoryTypeDevice: mem_strategy_ = memory_strategy::device_only; break; + } + + RAFT_LOG_DEBUG("Memory strategy: %d for type %d, size %zu", + static_cast(mem_strategy_), + static_cast(attr_.type), + host_view.extent(0) * host_view.extent(1) * sizeof(T)); + + // buffer allocations + if (mem_strategy_ == memory_strategy::copy_device) { + try { + device_mem_[0].emplace(raft::make_device_mdarray( + res, + raft::resource::get_workspace_resource(res), + raft::make_extents(batch_size, host_view.extent(1)))); + device_ptr[0] = device_mem_[0]->data_handle(); + if (batch_size < static_cast(host_view.extent(0))) { + device_mem_[1].emplace(raft::make_device_mdarray( + res, + raft::resource::get_workspace_resource(res), + raft::make_extents(batch_size, host_view.extent(1)))); + device_ptr[1] = device_mem_[1]->data_handle(); + } + if (host_writeback_ && initialize_ && + batch_size * 2 < static_cast(host_view.extent(0))) { + num_buffers_ = 3; + device_mem_[2].emplace(raft::make_device_mdarray( + res, + raft::resource::get_workspace_resource(res), + raft::make_extents(batch_size, host_view.extent(1)))); + device_ptr[2] = device_mem_[2]->data_handle(); + } + } catch (std::bad_alloc& e) { + if (attr_.devicePointer != nullptr) { + RAFT_LOG_DEBUG("Insufficient memory for device buffers, switching to managed memory"); + mem_strategy_ = memory_strategy::managed_only; + } else { + throw std::bad_alloc(); + } + } catch (raft::logic_error& e) { + if (attr_.devicePointer != nullptr) { + RAFT_LOG_DEBUG( + "Insufficient memory for device buffers (logic error), switching to managed memory"); + mem_strategy_ = memory_strategy::managed_only; + } else { + throw raft::logic_error("Insufficient memory for device buffers (logic error)"); + } + } + } + + // setup stream pool if not already present + size_t required_streams = host_writeback_ && initialize_ ? 2 : 1; + if (!res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) || + raft::resource::get_stream_pool_size(res) < required_streams) { + // always create at least 2 streams to account for subsequent iterator calls + raft::resource::set_cuda_stream_pool(res, std::make_shared(2)); + } + prefetch_stream_ = raft::resource::get_stream_from_stream_pool(res); + writeback_stream_ = raft::resource::get_stream_from_stream_pool(res); + + // if data is managed and not for_write_ we can set the attribute on the device ptr + if (mem_strategy_ == memory_strategy::managed_only) { + location_.type = cudaMemLocationTypeDevice; + location_.id = static_cast(raft::resource::get_device_id(res_)); + if (!host_writeback_) { + advise_read_mostly(host_view_.data_handle(), + host_view_.extent(0) * host_view_.extent(1) * sizeof(T)); + // TODO maybe also reset upon destruction + } + } + + // prefetch next batch (0) + prefetch_next_batch(); + } + + ~batched_device_view_from_host() noexcept + { + raft::resource::sync_stream(res_); + + // if data is on host and for_write --> make sure to copy back last active + // if data is managed and evict --> evict last active + + // make sure to sync on prefetch stream & res + switch (mem_strategy_) { + case memory_strategy::managed_only: + if (!host_writeback_) { + uint32_t discard_pos = batch_id_ % num_buffers_; + size_t discard_size_rows = actual_batch_size_[discard_pos]; + if (batch_id_ > 0) { + discard_pos = (batch_id_ - 1) % num_buffers_; + discard_size_rows += batch_size_; + } + discard_managed_region(device_ptr[discard_pos], + discard_size_rows * host_view_.extent(1) * sizeof(T)); + writeback_stream_.synchronize(); + } + break; + case memory_strategy::copy_device: + if (host_writeback_) { + uint32_t writeback_pos_last = batch_id_ % num_buffers_; + if (batch_id_ > 0) { + uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; + uint64_t writeback_offset = (batch_id_ - 1) * batch_size_; + writeback_from_device_to_host(device_ptr[writeback_pos], writeback_offset, batch_size_); + } + { + uint64_t writeback_offset_last = batch_id_ * batch_size_; + writeback_from_device_to_host(device_ptr[writeback_pos_last], + writeback_offset_last, + actual_batch_size_[writeback_pos_last]); + } + writeback_stream_.synchronize(); + } + break; + case memory_strategy::device_only: break; + } + } + + /** + * Returns the next view of the batch + * + * This function will ensure the next batch is ready and will trigger the prefetch of the + * subsequent next batch. If writeback is enabled, the last active batch will be written back to + * the host. + * + * @return The next view of the batch + */ + raft::device_matrix_view next_view() + { + bool end_of_data = static_cast((batch_id_ + 1) * batch_size_) >= + static_cast(host_view_.extent(0)); + + // special case for empty host view or last batch surpassed + if (end_of_data) { + return raft::make_device_matrix_view(nullptr, 0, host_view_.extent(1)); + } + + // trigger prefetch of next batch (also increments batch_id_) + prefetch_next_batch(); + + uint32_t current_pos = batch_id_ % num_buffers_; + return raft::make_device_matrix_view( + device_ptr[current_pos], actual_batch_size_[current_pos], host_view_.extent(1)); + } + + private: + /** + * Prefetch the next batch + * + * This function will prefetch the next batch and will handle the writeback of the data. + * + * @return True if the next batch exists, false otherwise + */ + bool prefetch_next_batch() + { + batch_id_++; + + // ensure previous batch at position batch_id_ is ready + if (initialize_) { prefetch_stream_.synchronize(); } + if (host_writeback_) { writeback_stream_.synchronize(); } + + // this step will + // * write back data from batch_id_ - 1 + // * prefetch data for batch_id_ + 1 + + // if data is on host and host_writeback_ is true we will have to copy it back + // if data is on host and initialize_ is true we will have to copy it to the device_ptr + + // if data is managed and !host_writeback_ we can discard the data from device memory + // if data is managed and initialize_ is true we can prefetch it to the device + // if data is managed and !initialize_ we can discard and prefetch the data location + + // if data is on device only this is almost a noop, just prepping the pointers + + RAFT_EXPECTS(static_cast(offset_) <= host_view_.extent(0), "Offset out of bounds"); + + bool next_batch_exists = offset_ < static_cast(host_view_.extent(0)); + + if (next_batch_exists) { + // synchronize to ensure all previous operations are completed + // in particular all work on batch_id_ - 1 + raft::resource::sync_stream(res_); + + int32_t prefetch_pos = (batch_id_ + 1) % num_buffers_; + actual_batch_size_[prefetch_pos] = min(batch_size_, host_view_.extent(0) - offset_); + + switch (mem_strategy_) { + case memory_strategy::managed_only: + if (!host_writeback_ && batch_id_ > 1) { + uint32_t discard_pos = (batch_id_ - 1) % num_buffers_; + size_t discard_size = batch_size_ * host_view_.extent(1) * sizeof(T); + discard_managed_region(device_ptr[discard_pos], discard_size); + } + // prefetch next position + device_ptr[prefetch_pos] = host_view_.data_handle() + offset_ * host_view_.extent(1); + prefetch_managed_region( + device_ptr[prefetch_pos], + actual_batch_size_[prefetch_pos] * host_view_.extent(1) * sizeof(T)); + break; + case memory_strategy::copy_device: + if (host_writeback_ && batch_id_ > 0) { + // copy back last active + uint32_t writeback_pos = (batch_id_ - 1) % num_buffers_; + uint64_t writeback_offset = (batch_id_ - 1) * batch_size_; + writeback_from_device_to_host(device_ptr[writeback_pos], writeback_offset, batch_size_); + } + if (initialize_) { + // prefetch next position + prefetch_from_host_to_device( + device_ptr[prefetch_pos], offset_, actual_batch_size_[prefetch_pos]); + } + + break; + case memory_strategy::device_only: + // just move pointer to next position + device_ptr[prefetch_pos] = host_view_.data_handle() + offset_ * host_view_.extent(1); + break; + } + + offset_ += actual_batch_size_[prefetch_pos]; + } + + return next_batch_exists; + } + + void advise_read_mostly(T* ptr, size_t size) + { +#if CUDA_VERSION >= 13000 + RAFT_CUDA_TRY(cudaMemAdvise(ptr, size, cudaMemAdviseSetReadMostly, location_)); +#else + RAFT_CUDA_TRY(cudaMemAdvise_v2(ptr, size, cudaMemAdviseSetReadMostly, location_)); +#endif + } + + void discard_managed_region(T* dev_ptr, size_t size) + { +#if CUDA_VERSION >= 13000 + void* dptrs[1] = {dev_ptr}; + size_t sizes[1] = {size}; + RAFT_CUDA_TRY(cudaMemDiscardBatchAsync(dptrs, sizes, 1, 0, writeback_stream_)); +#endif + // FIXME: CUDA12 does not support discard + } + + void prefetch_managed_region(T* dev_ptr, size_t size) + { +#if CUDA_VERSION >= 13000 + if (initialize_) { + RAFT_CUDA_TRY(cudaMemPrefetchAsync(dev_ptr, size, location_, 0, prefetch_stream_)); + } else { + void* dptrs[1] = {dev_ptr}; + size_t sizes[1] = {size}; + RAFT_CUDA_TRY( + cudaMemDiscardAndPrefetchBatchAsync(dptrs, sizes, 1, location_, 0, prefetch_stream_)); + } +#else + // FIXME: CUDA12 does not support discard - so we just prefetch + if (initialize_) { + RAFT_CUDA_TRY(cudaMemPrefetchAsync_v2(dev_ptr, size, location_, 0, prefetch_stream_)); + } else { + RAFT_CUDA_TRY(cudaMemPrefetchAsync_v2(dev_ptr, size, location_, 0, prefetch_stream_)); + } +#endif + } + + void prefetch_from_host_to_device(T* dev_ptr, size_t src_row_offset, size_t num_rows) + { + const size_t n_elem = num_rows * host_view_.extent(1); + const size_t n_bytes = n_elem * sizeof(T); + // use memcpy instead of raft::copy to avoid strange behavior with HMM/ATS memory + RAFT_CUDA_TRY(cudaMemcpyAsync(dev_ptr, + host_view_.data_handle() + src_row_offset * host_view_.extent(1), + n_bytes, + cudaMemcpyHostToDevice, + prefetch_stream_)); + } + + void writeback_from_device_to_host(T* dev_ptr, size_t dst_row_offset, size_t num_rows) + { + const size_t n_elem = num_rows * host_view_.extent(1); + const size_t n_bytes = n_elem * sizeof(T); + // use memcpy instead of raft::copy to avoid strange behavior with HMM/ATS memory + RAFT_CUDA_TRY(cudaMemcpyAsync(host_view_.data_handle() + dst_row_offset * host_view_.extent(1), + dev_ptr, + n_bytes, + cudaMemcpyDeviceToHost, + writeback_stream_)); + } + + // stream pool for local streams + std::optional> local_stream_pool_; + rmm::cuda_stream_view prefetch_stream_; + rmm::cuda_stream_view writeback_stream_; + + // configuration + memory_strategy mem_strategy_; + const raft::resources& res_; + bool initialize_; // initialize the data on the device + bool host_writeback_; // write back the data to the host + + // batch position information + uint64_t batch_size_; + int32_t batch_id_; + uint64_t offset_; + + cudaMemLocation location_; + + // input pointer information + raft::host_matrix_view host_view_; + cudaPointerAttributes attr_; + + // internal device buffers + uint64_t num_buffers_; + std::optional> device_mem_[3]; + T* device_ptr[3]; + uint32_t actual_batch_size_[3]; +}; + } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index bbddef87e5..2138cb070b 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -175,6 +175,7 @@ ConfigureTest( ConfigureTest( NAME NEIGHBORS_ANN_CAGRA_HELPERS_TEST PATH neighbors/ann_cagra/test_optimize_uint32_t.cu + neighbors/ann_cagra/test_batched_device_view_from_host.cu GPUS 1 PERCENT 100 ) diff --git a/cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu b/cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu new file mode 100644 index 0000000000..1e1cc13093 --- /dev/null +++ b/cpp/tests/neighbors/ann_cagra/test_batched_device_view_from_host.cu @@ -0,0 +1,205 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../../src/neighbors/detail/cagra/utils.hpp" + +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra { + +using IdxT = uint32_t; + +struct BatchConfig { + bool initialize; + bool host_writeback; +}; + +struct DimsConfig { + int64_t n_rows; + int64_t n_cols; + uint64_t batch_size; +}; + +class BatchedDeviceViewFromHostTest : public ::testing::Test { + protected: + void SetUp() override { raft::resource::sync_stream(res); } + + /** + * Run batched_device_view_from_host over host data, copy device views back, + * and verify against the input. + */ + template + void run_and_verify_batched(InputMatrixView input_view, + uint64_t batch_size, + bool host_writeback, + bool initialize) + { + int64_t n_rows = input_view.extent(0); + int64_t n_cols = input_view.extent(1); + + std::vector readback(n_rows * n_cols); + + int64_t total_processed = 0; + + { + cagra::detail::batched_device_view_from_host batched( + res, + raft::make_host_matrix_view(input_view.data_handle(), n_rows, n_cols), + batch_size, + host_writeback, + initialize); + while (true) { + auto dev_view = batched.next_view(); + if (dev_view.extent(0) == 0) break; + + if (initialize) { + raft::copy(readback.data() + total_processed * n_cols, + dev_view.data_handle(), + dev_view.extent(0) * dev_view.extent(1), + raft::resource::get_cuda_stream(res)); + } + if (host_writeback) { raft::matrix::fill(res, dev_view, IdxT(17)); } + total_processed += dev_view.extent(0); + } + } + raft::resource::sync_stream(res); + + EXPECT_EQ(total_processed, n_rows); + if (initialize) { + for (int64_t i = 0; i < n_rows * n_cols; ++i) { + EXPECT_EQ(readback[i], IdxT(13)) << "Mismatch (initialize) at index " << i; + } + } + if (host_writeback) { + auto readback_view = + raft::make_host_matrix_view(readback.data(), n_rows, n_cols); + raft::copy(res, readback_view, input_view); + raft::resource::sync_stream(res); + for (int64_t i = 0; i < n_rows * n_cols; ++i) { + EXPECT_EQ(readback[i], IdxT(17)) << "Mismatch (host_writeback) at index " << i; + } + } + } + + raft::resources res; +}; + +TEST_F(BatchedDeviceViewFromHostTest, EmptyView) +{ + auto host_empty = raft::make_host_matrix(0, 8); + auto host_view = host_empty.view(); + cagra::detail::batched_device_view_from_host batched( + res, host_view, /*batch_size=*/128, /*host_writeback=*/false, /*initialize=*/true); + + auto view = batched.next_view(); + EXPECT_EQ(view.extent(0), 0); + EXPECT_EQ(view.extent(1), 8); + EXPECT_EQ(view.data_handle(), nullptr); +} + +using BatchDimsParam = std::tuple; + +class BatchedDeviceViewFromHostParameterizedTest + : public BatchedDeviceViewFromHostTest, + public ::testing::WithParamInterface {}; + +TEST_P(BatchedDeviceViewFromHostParameterizedTest, VectorHostData) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + std::vector host_data(n_rows * n_cols); + auto host_view = raft::make_host_matrix_view(host_data.data(), n_rows, n_cols); + + std::fill(host_view.data_handle(), host_view.data_handle() + n_rows * n_cols, IdxT(13)); + + run_and_verify_batched(host_view, batch_size, host_writeback, initialize); +} + +TEST_P(BatchedDeviceViewFromHostParameterizedTest, PinnedMemory) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + auto host_matrix = raft::make_pinned_matrix(res, n_rows, n_cols); + auto host_view = host_matrix.view(); + + std::fill(host_view.data_handle(), host_view.data_handle() + n_rows * n_cols, IdxT(13)); + + run_and_verify_batched(host_view, batch_size, host_writeback, initialize); +} + +TEST_P(BatchedDeviceViewFromHostParameterizedTest, ManagedMemory) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + auto host_matrix = raft::make_managed_matrix(res, n_rows, n_cols); + auto host_view = host_matrix.view(); + + std::fill(host_view.data_handle(), host_view.data_handle() + n_rows * n_cols, IdxT(13)); + + run_and_verify_batched(host_view, batch_size, host_writeback, initialize); +} + +TEST_P(BatchedDeviceViewFromHostParameterizedTest, DeviceMemory) +{ + auto [batch_config, dims_config] = GetParam(); + auto [initialize, host_writeback] = batch_config; + auto [n_rows, n_cols, batch_size] = dims_config; + + auto host_matrix = raft::make_device_matrix(res, n_rows, n_cols); + auto host_view = host_matrix.view(); + + raft::matrix::fill(res, host_view, IdxT(13)); + + run_and_verify_batched(host_view, batch_size, host_writeback, initialize); +} + +static const std::array kBatchConfigs = {{ + {/*initialize=*/true, /*host_writeback=*/false}, + {/*initialize=*/false, /*host_writeback=*/true}, + {/*initialize=*/true, /*host_writeback=*/true}, +}}; + +static const std::array kDimsConfigs = {{ + {/*n_rows=*/64, /*n_cols=*/32, /*batch_size=*/256}, // rows less than batch size, single batch + {/*n_rows=*/64, /*n_cols=*/32, /*batch_size=*/64}, // single batch + {/*n_rows=*/256, /*n_cols=*/32, /*batch_size=*/32}, // multiple batches + {/*n_rows=*/500, + /*n_cols=*/32, + /*batch_size=*/128}, // multiple batches, partial batch in the end +}}; + +INSTANTIATE_TEST_SUITE_P(BatchConfigs, + BatchedDeviceViewFromHostParameterizedTest, + ::testing::Combine(::testing::ValuesIn(kBatchConfigs), + ::testing::ValuesIn(kDimsConfigs))); + +} // namespace cuvs::neighbors::cagra