diff --git a/cpp/src/neighbors/scann/detail/scann_avq.cuh b/cpp/src/neighbors/scann/detail/scann_avq.cuh index e7c1663f3e..6c3bb045e4 100644 --- a/cpp/src/neighbors/scann/detail/scann_avq.cuh +++ b/cpp/src/neighbors/scann/detail/scann_avq.cuh @@ -511,11 +511,16 @@ class cluster_loader { raft::make_device_matrix_view(d_cluster_copy_buf_.data_handle(), size, dim_); if (needs_copy_) { + // For prefetching to overlap with other gpu work + // we need to schedule copies on the provided copy stream stream_ + auto copy_res = raft::resources(res); + raft::resource::set_cuda_stream(copy_res, stream_); + // htod auto h_cluster_ids = raft::make_pinned_vector_view(cluster_ids_buf_.data_handle(), size); - raft::copy(res, h_cluster_ids, cluster_ids); + raft::copy(copy_res, h_cluster_ids, cluster_ids); raft::resource::sync_stream(res, stream_); auto pinned_cluster = raft::make_pinned_matrix_view( @@ -529,9 +534,8 @@ class cluster_loader { sizeof(T) * dim_); } - raft::copy(res, cluster_vectors, raft::make_const_mdspan(pinned_cluster)); + raft::copy(copy_res, cluster_vectors, raft::make_const_mdspan(pinned_cluster)); raft::resource::sync_stream(res, stream_); - } else { // dtod auto dataset_view =