From 23059eaa6bab9144becdad2cddd23f9846f8bd5a Mon Sep 17 00:00:00 2001 From: Max Buckley Date: Tue, 26 May 2026 17:52:30 +0200 Subject: [PATCH] perf(brute_force): skip csr_to_coo + rows alloc on inner-product filtered search The CSR path of the filtered brute-force search builds a per-nonzero `rows` array via `raft::sparse::convert::csr_to_coo`. That array is only consumed by `epilogue_on_csr`, which itself only runs for L2 / L2Sqrt / Cosine to combine the masked inner products with the precomputed norms. For InnerProduct the epilogue is skipped, so the allocation and kernel launch were dead work on every IP-metric filtered search. Move the `rmm::device_uvector<IdxT> rows(...)` allocation and the `csr_to_coo` call inside the L2/Cosine branch alongside their only consumer. No behavior change for L2/L2Sqrt/Cosine. For InnerProduct filtered searches that hit the sparse CSR path (sparsity >= 0.9), this saves one device allocation of `nnz * sizeof(IdxT)` and one kernel launch + write pass per call. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Max Buckley --- cpp/src/neighbors/detail/knn_brute_force.cuh | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh index 34ef8dd937..3bf65883dd 100644 --- a/cpp/src/neighbors/detail/knn_brute_force.cuh +++ b/cpp/src/neighbors/detail/knn_brute_force.cuh @@ -671,13 +671,7 @@ void brute_force_search_filtered( // create filter csr view auto compressed_csr_view = csr.structure_view(); - rmm::device_uvector rows(compressed_csr_view.get_nnz(), stream); - raft::sparse::convert::csr_to_coo(compressed_csr_view.get_indptr().data(), - compressed_csr_view.get_n_rows(), - rows.data(), - compressed_csr_view.get_nnz(), - stream); - auto dataset_view = raft::make_device_matrix_view( + auto dataset_view = raft::make_device_matrix_view( idx.dataset().data_handle(), n_dataset, dim); auto csr_view = raft::make_device_csr_matrix_view( @@ 
-714,6 +708,14 @@ void brute_force_search_filtered( query_norms_->view()); } } + // rows array (COO row indices) is only needed by the L2/Cosine epilogue, + // so build it here rather than unconditionally above. + rmm::device_uvector rows(compressed_csr_view.get_nnz(), stream); + raft::sparse::convert::csr_to_coo(compressed_csr_view.get_indptr().data(), + compressed_csr_view.get_n_rows(), + rows.data(), + compressed_csr_view.get_nnz(), + stream); cuvs::neighbors::detail::epilogue_on_csr( res, csr.get_elements().data(),