From 71decfcd94084ea0befebea78be1b41fb6879ef9 Mon Sep 17 00:00:00 2001 From: Max Buckley Date: Sat, 23 May 2026 14:56:21 +0200 Subject: [PATCH] Promote filtering_rate to base search_params; honor it in brute_force MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `filtering_rate` (a hint of the fraction of items filtered out by the sample filter) previously lived on `cagra::search_params` only. CAGRA used it to size `itopk_size` and, when the user left it at its default of `-1.0`, called `bitset_view::count(res)` on every search — a GPU popcount reduction + host sync that adds measurable latency. `brute_force` filtered search also called `bitset_view::count(res)` on every search to compute `sparsity` (used to choose between the dense tiled GEMM path and the sparse CSR path), but had no user-facing knob to skip the auto-detection. This change: - Moves `float filtering_rate = -1.0` from `cagra::search_params` to the base `cuvs::neighbors::search_params`. `cagra::search_params` inherits it; existing code that accesses `params.filtering_rate` is unaffected. - Plumbs `brute_force::search_params` through `detail::search` and `brute_force_search_filtered`. When `params.filtering_rate >= 0`, the hint is used directly as `sparsity` and the per-search popcount is skipped on the dense path. The CSR path still needs an exact non-zero count to size the matrix, so popcount runs lazily there. - Algorithms that don't use the hint (`ivf_flat`, `ivf_pq`, `hnsw`) ignore the new base field; their `search_params` inherit it but nothing reads it. Tests: extends `brute_force_prefiltered.cu` with `ResultWithFilteringRateHint` cases for both bitmap and bitset fixtures (float and half), reusing the existing parameter matrix and asserting that results match the auto-detect path. Closes #1960. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- cpp/include/cuvs/neighbors/cagra.hpp | 9 +- cpp/include/cuvs/neighbors/common.hpp | 18 ++- cpp/src/neighbors/brute_force.cu | 8 +- cpp/src/neighbors/detail/knn_brute_force.cuh | 43 ++++++-- .../neighbors/brute_force_prefiltered.cu | 104 ++++++++++++++++++ 5 files changed, 163 insertions(+), 19 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index 8edbcab8fa..a263cb5d76 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -343,12 +343,9 @@ struct search_params : cuvs::neighbors::search_params { */ float persistent_device_usage = 1.0; - /** - * A parameter indicating the rate of nodes to be filtered-out, when filtering is used. - * The value must be equal to or greater than 0.0 and less than 1.0. Default value is - * negative, in which case the filtering rate is automatically calculated. - */ - float filtering_rate = -1.0; + // `filtering_rate` is inherited from `cuvs::neighbors::search_params`. CAGRA uses it to + // size `itopk_size`; supplying a non-negative value avoids a per-search popcount + host sync + // on the `bitset_filter` path. }; /** diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 887593c23b..f8da58215f 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -116,7 +116,23 @@ struct index_params { float metric_arg = 2.0f; }; -struct search_params {}; +struct search_params { + /** + * A hint indicating the rate at which the sample filter is expected to filter out items + * (i.e. `(n_dataset - n_set_bits) / n_dataset` for a `bitset_filter`). + * + * Algorithms that benefit from knowing the filter's selectivity may use this hint to tune + * internal parameters or skip a hidden popcount kernel + host sync that would otherwise be + * required to derive it from the filter on every search call. 
+ * - Negative (default): the algorithm auto-detects the rate from the filter when needed. + * This launches a GPU popcount reduction and synchronizes the stream per search call. + * - In `[0.0, 1.0)`: the algorithm trusts the supplied value and skips the auto-detection where it can; `brute_force` still counts the filter's set bits when it takes its sparse CSR path, which needs the exact count. + * + * Algorithms that do not use this hint (e.g. `ivf_flat`, `ivf_pq`) ignore it. + */ + float filtering_rate = -1.0; +}; /** * @brief Strategy for merging indices. diff --git a/cpp/src/neighbors/brute_force.cu b/cpp/src/neighbors/brute_force.cu index 2f9000acf7..c4d9b002f3 100644 --- a/cpp/src/neighbors/brute_force.cu +++ b/cpp/src/neighbors/brute_force.cu @@ -199,7 +199,7 @@ void index::update_dataset( const cuvs::neighbors::filtering::base_filter& sample_filter) \ { \ detail::search( \ - res, idx, queries, neighbors, distances, sample_filter); \ + res, params, idx, queries, neighbors, distances, sample_filter); \ } \ void search(raft::resources const& res, \ const cuvs::neighbors::brute_force::index& idx, \ @@ -209,7 +209,7 @@ void index::update_dataset( const cuvs::neighbors::filtering::base_filter& sample_filter) \ { \ detail::search( \ - res, idx, queries, neighbors, distances, sample_filter); \ + res, {}, idx, queries, neighbors, distances, sample_filter); \ } \ void search(raft::resources const& res, \ const cuvs::neighbors::brute_force::search_params& params, \ @@ -220,7 +220,7 @@ void index::update_dataset( const cuvs::neighbors::filtering::base_filter& sample_filter) \ { \ detail::search( \ - res, idx, queries, neighbors, distances, sample_filter); \ + res, params, idx, queries, neighbors, distances, sample_filter); \ } \ void search(raft::resources const& res, \ const cuvs::neighbors::brute_force::index& idx, \ @@ -230,7 +230,7 @@ void index::update_dataset( const cuvs::neighbors::filtering::base_filter& sample_filter) \ { \ detail::search( \ - res, idx, queries, neighbors, distances, sample_filter); \ + res, {}, idx, queries, neighbors, distances, sample_filter); \ } 
CUVS_INST_BFKNN(float, float); diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh index 34ef8dd937..63d6c0210e 100644 --- a/cpp/src/neighbors/detail/knn_brute_force.cuh +++ b/cpp/src/neighbors/detail/knn_brute_force.cuh @@ -584,6 +584,7 @@ void brute_force_search( template void brute_force_search_filtered( raft::resources const& res, + const cuvs::neighbors::brute_force::search_params& params, const cuvs::neighbors::brute_force::index& idx, raft::device_matrix_view queries, const cuvs::neighbors::filtering::base_filter* filter, @@ -619,8 +620,10 @@ void brute_force_search_filtered( const cuvs::core::bitset_view>> filter_view; - IdxT nnz_h = 0; - float sparsity = 0.0f; + IdxT nnz_h = 0; + float sparsity = 0.0f; + bool nnz_h_is_set = false; + const bool use_hint = params.filtering_rate >= 0.0f; const BitsT* filter_data = nullptr; @@ -628,14 +631,24 @@ void brute_force_search_filtered( auto actual_filter = dynamic_cast*>(filter); filter_view.emplace(actual_filter->view()); - nnz_h = actual_filter->view().count(res); - sparsity = 1.0 - nnz_h / (1.0 * n_queries * n_dataset); + if (use_hint) { + sparsity = params.filtering_rate; + } else { + nnz_h = actual_filter->view().count(res); + sparsity = 1.0 - nnz_h / (1.0 * n_queries * n_dataset); + nnz_h_is_set = true; + } } else if (filter_type == cuvs::neighbors::filtering::FilterType::Bitset) { auto actual_filter = dynamic_cast*>(filter); filter_view.emplace(actual_filter->view()); - nnz_h = n_queries * actual_filter->view().count(res); - sparsity = 1.0 - nnz_h / (1.0 * n_queries * n_dataset); + if (use_hint) { + sparsity = params.filtering_rate; + } else { + nnz_h = n_queries * actual_filter->view().count(res); + sparsity = 1.0 - nnz_h / (1.0 * n_queries * n_dataset); + nnz_h_is_set = true; + } } else { RAFT_FAIL("Unsupported sample filter type"); } @@ -666,6 +679,19 @@ void brute_force_search_filtered( raft::identity_op(), filter_type); } else { + // The CSR path needs 
an exact non-zero count to size the matrix. If the hint was used, + // popcount lazily here — we still save the kernel + sync on the (more common) dense path. + if (!nnz_h_is_set) { + if (filter_type == cuvs::neighbors::filtering::FilterType::Bitmap) { + auto actual_filter = + dynamic_cast*>(filter); + nnz_h = actual_filter->view().count(res); + } else { + auto actual_filter = + dynamic_cast*>(filter); + nnz_h = n_queries * actual_filter->view().count(res); + } + } auto csr = raft::make_device_csr_matrix(res, n_queries, n_dataset, nnz_h); std::visit([&](const auto& actual_view) { actual_view.to_csr(res, csr); }, *filter_view); @@ -739,6 +765,7 @@ void brute_force_search_filtered( template void search(raft::resources const& res, + const cuvs::neighbors::brute_force::search_params& params, const cuvs::neighbors::brute_force::index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, @@ -759,7 +786,7 @@ void search(raft::resources const& res, dynamic_cast&>( sample_filter_ref); return brute_force_search_filtered( - res, idx, queries, &sample_filter, neighbors, distances); + res, params, idx, queries, &sample_filter, neighbors, distances); } catch (const std::bad_cast&) { } @@ -768,7 +795,7 @@ void search(raft::resources const& res, dynamic_cast&>( sample_filter_ref); return brute_force_search_filtered( - res, idx, queries, &sample_filter, neighbors, distances); + res, params, idx, queries, &sample_filter, neighbors, distances); } catch (const std::bad_cast&) { RAFT_FAIL("Unsupported sample filter type"); } diff --git a/cpp/tests/neighbors/brute_force_prefiltered.cu b/cpp/tests/neighbors/brute_force_prefiltered.cu index c15f08363f..7553490acc 100644 --- a/cpp/tests/neighbors/brute_force_prefiltered.cu +++ b/cpp/tests/neighbors/brute_force_prefiltered.cu @@ -520,6 +520,50 @@ class PrefilteredBruteForceOnBitmapTest true)); } + // Same as Run(), but passes the true sparsity as a filtering_rate hint to the + // params-taking search overload. 
Confirms results match auto-detection. + void RunWithFilteringRateHint() + { + auto dataset_raw = raft::make_device_matrix_view( + (const value_t*)dataset_d.data(), params.n_dataset, params.dim); + + auto queries = raft::make_device_matrix_view( + (const value_t*)queries_d.data(), params.n_queries, params.dim); + + auto dataset = brute_force::build(handle, dataset_raw, params.metric); + + auto filter = cuvs::core::bitmap_view( + (bitmap_t*)filter_d.data(), params.n_queries, params.n_dataset); + + auto out_val = raft::make_device_matrix_view( + out_val_d.data(), params.n_queries, params.top_k); + auto out_idx = raft::make_device_matrix_view( + out_idx_d.data(), params.n_queries, params.top_k); + + cuvs::neighbors::brute_force::search_params search_params; + search_params.filtering_rate = params.sparsity; + + brute_force::search(handle, + search_params, + dataset, + queries, + out_idx, + out_val, + cuvs::neighbors::filtering::bitmap_filter(filter)); + + raft::resource::sync_stream(handle); + + ASSERT_TRUE(cuvs::neighbors::devArrMatchKnnPair(out_idx_expected_d.data(), + out_idx.data_handle(), + out_val_expected_d.data(), + out_val.data_handle(), + params.n_queries, + params.top_k, + 0.001f, + stream, + true)); + } + protected: raft::resources handle; cudaStream_t stream; @@ -941,6 +985,50 @@ class PrefilteredBruteForceOnBitsetTest true)); } + // Same as Run(), but passes the true sparsity as a filtering_rate hint to the + // params-taking search overload. Confirms results match auto-detection. 
+ void RunWithFilteringRateHint() + { + auto dataset_raw = raft::make_device_matrix_view( + (const value_t*)dataset_d.data(), params.n_dataset, params.dim); + + auto queries = raft::make_device_matrix_view( + (const value_t*)queries_d.data(), params.n_queries, params.dim); + + auto dataset = brute_force::build(handle, dataset_raw, params.metric); + + auto filter = + cuvs::core::bitset_view((bitset_t*)filter_d.data(), params.n_dataset); + + auto out_val = raft::make_device_matrix_view( + out_val_d.data(), params.n_queries, params.top_k); + auto out_idx = raft::make_device_matrix_view( + out_idx_d.data(), params.n_queries, params.top_k); + + cuvs::neighbors::brute_force::search_params search_params; + search_params.filtering_rate = params.sparsity; + + brute_force::search(handle, + search_params, + dataset, + queries, + out_idx, + out_val, + cuvs::neighbors::filtering::bitset_filter(filter)); + + raft::resource::sync_stream(handle); + + ASSERT_TRUE(cuvs::neighbors::devArrMatchKnnPair(out_idx_expected_d.data(), + out_idx.data_handle(), + out_val_expected_d.data(), + out_val.data_handle(), + params.n_queries, + params.top_k, + 0.001f, + stream, + true)); + } + protected: raft::resources handle; cudaStream_t stream; @@ -963,18 +1051,34 @@ class PrefilteredBruteForceOnBitsetTest using PrefilteredBruteForceTestOnBitmap_float_int64 = PrefilteredBruteForceOnBitmapTest; TEST_P(PrefilteredBruteForceTestOnBitmap_float_int64, Result) { Run(); } +TEST_P(PrefilteredBruteForceTestOnBitmap_float_int64, ResultWithFilteringRateHint) +{ + RunWithFilteringRateHint(); +} using PrefilteredBruteForceTestOnBitmap_half_int64 = PrefilteredBruteForceOnBitmapTest; TEST_P(PrefilteredBruteForceTestOnBitmap_half_int64, Result) { Run(); } +TEST_P(PrefilteredBruteForceTestOnBitmap_half_int64, ResultWithFilteringRateHint) +{ + RunWithFilteringRateHint(); +} using PrefilteredBruteForceTestOnBitset_float_int64 = PrefilteredBruteForceOnBitsetTest; TEST_P(PrefilteredBruteForceTestOnBitset_float_int64, 
Result) { Run(); } +TEST_P(PrefilteredBruteForceTestOnBitset_float_int64, ResultWithFilteringRateHint) +{ + RunWithFilteringRateHint(); +} using PrefilteredBruteForceTestOnBitset_half_int64 = PrefilteredBruteForceOnBitsetTest; TEST_P(PrefilteredBruteForceTestOnBitset_half_int64, Result) { Run(); } +TEST_P(PrefilteredBruteForceTestOnBitset_half_int64, ResultWithFilteringRateHint) +{ + RunWithFilteringRateHint(); +} template const std::vector> selectk_inputs = {