From d97099b84fc7c3c53c27882aab88e051a522fb81 Mon Sep 17 00:00:00 2001 From: Robert Maschal Date: Mon, 9 Mar 2026 13:08:48 -0700 Subject: [PATCH 1/2] ScaNN: Fix AVQ prefetch --- cpp/src/neighbors/scann/detail/scann_avq.cuh | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/cpp/src/neighbors/scann/detail/scann_avq.cuh b/cpp/src/neighbors/scann/detail/scann_avq.cuh index e7c1663f3e..4ef226ab86 100644 --- a/cpp/src/neighbors/scann/detail/scann_avq.cuh +++ b/cpp/src/neighbors/scann/detail/scann_avq.cuh @@ -511,6 +511,11 @@ class cluster_loader { raft::make_device_matrix_view(d_cluster_copy_buf_.data_handle(), size, dim_); if (needs_copy_) { + // For prefetching to overlap with other gpu work + // we need to schedule copies on the provided copy stream stream_ + auto stream = raft::resource::get_cuda_stream(res); + raft::resource::set_cuda_stream(res, stream_); + // htod auto h_cluster_ids = raft::make_pinned_vector_view(cluster_ids_buf_.data_handle(), size); @@ -532,6 +537,8 @@ class cluster_loader { raft::copy(res, cluster_vectors, raft::make_const_mdspan(pinned_cluster)); raft::resource::sync_stream(res, stream_); + // reset stream back to previous value + raft::resource::set_cuda_stream(res, stream); } else { // dtod auto dataset_view = From 22f74e7b18edbbd83e8c12dbfeba1a9dc18f7488 Mon Sep 17 00:00:00 2001 From: Robert Maschal Date: Tue, 10 Mar 2026 10:25:35 -0700 Subject: [PATCH 2/2] Copy resources instead of setting stream --- cpp/src/neighbors/scann/detail/scann_avq.cuh | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/cpp/src/neighbors/scann/detail/scann_avq.cuh b/cpp/src/neighbors/scann/detail/scann_avq.cuh index 4ef226ab86..6c3bb045e4 100644 --- a/cpp/src/neighbors/scann/detail/scann_avq.cuh +++ b/cpp/src/neighbors/scann/detail/scann_avq.cuh @@ -513,14 +513,14 @@ class cluster_loader { if (needs_copy_) { // For prefetching to overlap with other gpu work // we need to schedule copies on the provided copy stream stream_ - auto stream = raft::resource::get_cuda_stream(res); - raft::resource::set_cuda_stream(res, stream_); + auto copy_res = raft::resources(res); + raft::resource::set_cuda_stream(copy_res, stream_); // htod auto h_cluster_ids = raft::make_pinned_vector_view(cluster_ids_buf_.data_handle(), size); - raft::copy(res, h_cluster_ids, cluster_ids); + raft::copy(copy_res, h_cluster_ids, cluster_ids); raft::resource::sync_stream(res, stream_); auto pinned_cluster = raft::make_pinned_matrix_view( @@ -534,11 +534,8 @@ class cluster_loader { sizeof(T) * dim_); } - raft::copy(res, cluster_vectors, raft::make_const_mdspan(pinned_cluster)); + raft::copy(copy_res, cluster_vectors, raft::make_const_mdspan(pinned_cluster)); raft::resource::sync_stream(res, stream_); - - // reset stream back to previous value - raft::resource::set_cuda_stream(res, stream); } else { // dtod auto dataset_view =