diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh index 34ef8dd937..3bf65883dd 100644 --- a/cpp/src/neighbors/detail/knn_brute_force.cuh +++ b/cpp/src/neighbors/detail/knn_brute_force.cuh @@ -671,13 +671,7 @@ void brute_force_search_filtered( // create filter csr view auto compressed_csr_view = csr.structure_view(); - rmm::device_uvector rows(compressed_csr_view.get_nnz(), stream); - raft::sparse::convert::csr_to_coo(compressed_csr_view.get_indptr().data(), - compressed_csr_view.get_n_rows(), - rows.data(), - compressed_csr_view.get_nnz(), - stream); - auto dataset_view = raft::make_device_matrix_view( + auto dataset_view = raft::make_device_matrix_view( idx.dataset().data_handle(), n_dataset, dim); auto csr_view = raft::make_device_csr_matrix_view( @@ -714,6 +708,14 @@ void brute_force_search_filtered( query_norms_->view()); } } + // rows array (COO row indices) is only needed by the L2/Cosine epilogue, + // so build it here rather than unconditionally above. + rmm::device_uvector rows(compressed_csr_view.get_nnz(), stream); + raft::sparse::convert::csr_to_coo(compressed_csr_view.get_indptr().data(), + compressed_csr_view.get_n_rows(), + rows.data(), + compressed_csr_view.get_nnz(), + stream); cuvs::neighbors::detail::epilogue_on_csr( res, csr.get_elements().data(),