Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,12 +343,9 @@ struct search_params : cuvs::neighbors::search_params {
*/
float persistent_device_usage = 1.0;

/**
* A parameter indicating the rate of nodes to be filtered-out, when filtering is used.
* The value must be equal to or greater than 0.0 and less than 1.0. Default value is
* negative, in which case the filtering rate is automatically calculated.
*/
float filtering_rate = -1.0;
// `filtering_rate` is inherited from `cuvs::neighbors::search_params`. CAGRA uses it to
// size `itopk_size`; supplying a non-negative value avoids a per-search popcount + host sync
// on the `bitset_filter` path.
};

/**
Expand Down
18 changes: 17 additions & 1 deletion cpp/include/cuvs/neighbors/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,23 @@ struct index_params {
float metric_arg = 2.0f;
};

struct search_params {};
struct search_params {
/**
* A hint indicating the rate at which the sample filter is expected to filter out items
* (i.e. `(n_dataset - n_set_bits) / n_dataset` for a `bitset_filter`).
*
* Algorithms that benefit from knowing the filter's selectivity may use this hint to tune
* internal parameters or skip a hidden popcount kernel + host sync that would otherwise be
* required to derive it from the filter on every search call.
*
* - Negative (default): the algorithm auto-detects the rate from the filter when needed.
* This launches a GPU popcount reduction and synchronizes the stream per search call.
* - In `[0.0, 1.0)`: the algorithm trusts the supplied value and skips the auto-detection.
*
* Algorithms that do not use this hint (e.g. `ivf_flat`, `ivf_pq`) ignore it.
*/
float filtering_rate = -1.0;
};

/**
* @brief Strategy for merging indices.
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/neighbors/brute_force.cu
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ void index<T, DistT>::update_dataset(
const cuvs::neighbors::filtering::base_filter& sample_filter) \
{ \
detail::search<T, int64_t, DistT, raft::row_major>( \
res, idx, queries, neighbors, distances, sample_filter); \
res, params, idx, queries, neighbors, distances, sample_filter); \
} \
void search(raft::resources const& res, \
const cuvs::neighbors::brute_force::index<T, DistT>& idx, \
Expand All @@ -209,7 +209,7 @@ void index<T, DistT>::update_dataset(
const cuvs::neighbors::filtering::base_filter& sample_filter) \
{ \
detail::search<T, int64_t, DistT, raft::row_major>( \
res, idx, queries, neighbors, distances, sample_filter); \
res, {}, idx, queries, neighbors, distances, sample_filter); \
} \
void search(raft::resources const& res, \
const cuvs::neighbors::brute_force::search_params& params, \
Expand All @@ -220,7 +220,7 @@ void index<T, DistT>::update_dataset(
const cuvs::neighbors::filtering::base_filter& sample_filter) \
{ \
detail::search<T, int64_t, DistT, raft::col_major>( \
res, idx, queries, neighbors, distances, sample_filter); \
res, params, idx, queries, neighbors, distances, sample_filter); \
} \
void search(raft::resources const& res, \
const cuvs::neighbors::brute_force::index<T, DistT>& idx, \
Expand All @@ -230,7 +230,7 @@ void index<T, DistT>::update_dataset(
const cuvs::neighbors::filtering::base_filter& sample_filter) \
{ \
detail::search<T, int64_t, DistT, raft::col_major>( \
res, idx, queries, neighbors, distances, sample_filter); \
res, {}, idx, queries, neighbors, distances, sample_filter); \
}

CUVS_INST_BFKNN(float, float);
Expand Down
43 changes: 35 additions & 8 deletions cpp/src/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,7 @@ void brute_force_search(
template <typename T, typename IdxT, typename BitsT, typename DistanceT = float>
void brute_force_search_filtered(
raft::resources const& res,
const cuvs::neighbors::brute_force::search_params& params,
const cuvs::neighbors::brute_force::index<T, DistanceT>& idx,
raft::device_matrix_view<const T, IdxT, raft::row_major> queries,
const cuvs::neighbors::filtering::base_filter* filter,
Expand Down Expand Up @@ -619,23 +620,35 @@ void brute_force_search_filtered(
const cuvs::core::bitset_view<BitsT, IdxT>>>
filter_view;

IdxT nnz_h = 0;
float sparsity = 0.0f;
IdxT nnz_h = 0;
float sparsity = 0.0f;
bool nnz_h_is_set = false;
const bool use_hint = params.filtering_rate >= 0.0f;

Comment on lines +623 to 627
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Validate filtering_rate before treating it as a hint.

Line 626 currently accepts any non-negative value as valid. That allows out-of-contract inputs (e.g., >= 1.0) to silently affect path selection. Add an explicit range check and only enable hint mode for [0.0, 1.0).

🔧 Proposed fix
-  IdxT nnz_h          = 0;
-  float sparsity      = 0.0f;
-  bool nnz_h_is_set   = false;
-  const bool use_hint = params.filtering_rate >= 0.0f;
+  const bool is_auto = params.filtering_rate < 0.0f;
+  const bool is_hint = params.filtering_rate >= 0.0f && params.filtering_rate < 1.0f;
+  RAFT_EXPECTS(is_auto || is_hint,
+               "brute_force::search_params::filtering_rate must be negative (auto) or in [0, 1).");
+
+  IdxT nnz_h          = 0;
+  float sparsity      = 0.0f;
+  bool nnz_h_is_set   = false;
+  const bool use_hint = is_hint;

As per coding guidelines “Input validation must check for negative or invalid dimensions, null pointers, and invalid parameter combinations before GPU operations”.

Also applies to: 634-650

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/src/neighbors/detail/knn_brute_force.cuh` around lines 623 - 627, The
code currently enables hint mode for any non-negative params.filtering_rate;
change this to only enable hint mode when filtering_rate is in [0.0, 1.0) by
replacing the use_hint initialization with an explicit range check (e.g., const
bool use_hint = (params.filtering_rate >= 0.0f && params.filtering_rate <
1.0f)); additionally add an explicit validation branch near that declaration (in
the knn_brute_force.cuh scope where use_hint and params.filtering_rate are used)
that treats values >= 1.0f as invalid—either disable the hint and emit a clear
error/exception (e.g., throw std::invalid_argument or return an error) or log a
warning before proceeding—so out-of-contract inputs do not silently alter path
selection.

const BitsT* filter_data = nullptr;

if (filter_type == cuvs::neighbors::filtering::FilterType::Bitmap) {
auto actual_filter =
dynamic_cast<const cuvs::neighbors::filtering::bitmap_filter<BitsT, int64_t>*>(filter);
filter_view.emplace(actual_filter->view());
nnz_h = actual_filter->view().count(res);
sparsity = 1.0 - nnz_h / (1.0 * n_queries * n_dataset);
if (use_hint) {
sparsity = params.filtering_rate;
} else {
nnz_h = actual_filter->view().count(res);
sparsity = 1.0 - nnz_h / (1.0 * n_queries * n_dataset);
nnz_h_is_set = true;
}
} else if (filter_type == cuvs::neighbors::filtering::FilterType::Bitset) {
auto actual_filter =
dynamic_cast<const cuvs::neighbors::filtering::bitset_filter<BitsT, int64_t>*>(filter);
filter_view.emplace(actual_filter->view());
nnz_h = n_queries * actual_filter->view().count(res);
sparsity = 1.0 - nnz_h / (1.0 * n_queries * n_dataset);
if (use_hint) {
sparsity = params.filtering_rate;
} else {
nnz_h = n_queries * actual_filter->view().count(res);
sparsity = 1.0 - nnz_h / (1.0 * n_queries * n_dataset);
nnz_h_is_set = true;
}
} else {
RAFT_FAIL("Unsupported sample filter type");
}
Expand Down Expand Up @@ -666,6 +679,19 @@ void brute_force_search_filtered(
raft::identity_op(),
filter_type);
} else {
// The CSR path needs an exact non-zero count to size the matrix. If the hint was used,
// popcount lazily here — we still save the kernel + sync on the (more common) dense path.
if (!nnz_h_is_set) {
if (filter_type == cuvs::neighbors::filtering::FilterType::Bitmap) {
auto actual_filter =
dynamic_cast<const cuvs::neighbors::filtering::bitmap_filter<BitsT, int64_t>*>(filter);
nnz_h = actual_filter->view().count(res);
} else {
auto actual_filter =
dynamic_cast<const cuvs::neighbors::filtering::bitset_filter<BitsT, int64_t>*>(filter);
nnz_h = n_queries * actual_filter->view().count(res);
}
}
auto csr = raft::make_device_csr_matrix<DistanceT, IdxT>(res, n_queries, n_dataset, nnz_h);
std::visit([&](const auto& actual_view) { actual_view.to_csr(res, csr); }, *filter_view);

Expand Down Expand Up @@ -739,6 +765,7 @@ void brute_force_search_filtered(

template <typename T, typename IdxT, typename DistT, typename LayoutT>
void search(raft::resources const& res,
const cuvs::neighbors::brute_force::search_params& params,
const cuvs::neighbors::brute_force::index<T, DistT>& idx,
raft::device_matrix_view<const T, int64_t, LayoutT> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
Expand All @@ -759,7 +786,7 @@ void search(raft::resources const& res,
dynamic_cast<const cuvs::neighbors::filtering::bitmap_filter<uint32_t, int64_t>&>(
sample_filter_ref);
return brute_force_search_filtered<T, int64_t, uint32_t, DistT>(
res, idx, queries, &sample_filter, neighbors, distances);
res, params, idx, queries, &sample_filter, neighbors, distances);
} catch (const std::bad_cast&) {
}

Expand All @@ -768,7 +795,7 @@ void search(raft::resources const& res,
dynamic_cast<const cuvs::neighbors::filtering::bitset_filter<uint32_t, int64_t>&>(
sample_filter_ref);
return brute_force_search_filtered<T, int64_t, uint32_t, DistT>(
res, idx, queries, &sample_filter, neighbors, distances);
res, params, idx, queries, &sample_filter, neighbors, distances);
} catch (const std::bad_cast&) {
RAFT_FAIL("Unsupported sample filter type");
}
Expand Down
104 changes: 104 additions & 0 deletions cpp/tests/neighbors/brute_force_prefiltered.cu
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,50 @@ class PrefilteredBruteForceOnBitmapTest
true));
}

// Same as Run(), but passes the true sparsity as a filtering_rate hint to the
// params-taking search overload. Confirms results match auto-detection.
void RunWithFilteringRateHint()
{
auto dataset_raw = raft::make_device_matrix_view<const value_t, index_t, raft::row_major>(
(const value_t*)dataset_d.data(), params.n_dataset, params.dim);

auto queries = raft::make_device_matrix_view<const value_t, index_t, raft::row_major>(
(const value_t*)queries_d.data(), params.n_queries, params.dim);

auto dataset = brute_force::build(handle, dataset_raw, params.metric);

auto filter = cuvs::core::bitmap_view<bitmap_t, index_t>(
(bitmap_t*)filter_d.data(), params.n_queries, params.n_dataset);

auto out_val = raft::make_device_matrix_view<dist_t, index_t, raft::row_major>(
out_val_d.data(), params.n_queries, params.top_k);
auto out_idx = raft::make_device_matrix_view<index_t, index_t, raft::row_major>(
out_idx_d.data(), params.n_queries, params.top_k);

cuvs::neighbors::brute_force::search_params search_params;
search_params.filtering_rate = params.sparsity;

brute_force::search(handle,
search_params,
dataset,
queries,
out_idx,
out_val,
cuvs::neighbors::filtering::bitmap_filter(filter));

raft::resource::sync_stream(handle);

ASSERT_TRUE(cuvs::neighbors::devArrMatchKnnPair(out_idx_expected_d.data(),
out_idx.data_handle(),
out_val_expected_d.data(),
out_val.data_handle(),
params.n_queries,
params.top_k,
0.001f,
stream,
true));
}

protected:
raft::resources handle;
cudaStream_t stream;
Expand Down Expand Up @@ -941,6 +985,50 @@ class PrefilteredBruteForceOnBitsetTest
true));
}

// Same as Run(), but passes the true sparsity as a filtering_rate hint to the
// params-taking search overload. Confirms results match auto-detection.
void RunWithFilteringRateHint()
{
auto dataset_raw = raft::make_device_matrix_view<const value_t, index_t, raft::row_major>(
(const value_t*)dataset_d.data(), params.n_dataset, params.dim);

auto queries = raft::make_device_matrix_view<const value_t, index_t, raft::row_major>(
(const value_t*)queries_d.data(), params.n_queries, params.dim);

auto dataset = brute_force::build(handle, dataset_raw, params.metric);

auto filter =
cuvs::core::bitset_view<bitset_t, index_t>((bitset_t*)filter_d.data(), params.n_dataset);

auto out_val = raft::make_device_matrix_view<dist_t, index_t, raft::row_major>(
out_val_d.data(), params.n_queries, params.top_k);
auto out_idx = raft::make_device_matrix_view<index_t, index_t, raft::row_major>(
out_idx_d.data(), params.n_queries, params.top_k);

cuvs::neighbors::brute_force::search_params search_params;
search_params.filtering_rate = params.sparsity;

brute_force::search(handle,
search_params,
dataset,
queries,
out_idx,
out_val,
cuvs::neighbors::filtering::bitset_filter(filter));

raft::resource::sync_stream(handle);

ASSERT_TRUE(cuvs::neighbors::devArrMatchKnnPair(out_idx_expected_d.data(),
out_idx.data_handle(),
out_val_expected_d.data(),
out_val.data_handle(),
params.n_queries,
params.top_k,
0.001f,
stream,
true));
}

protected:
raft::resources handle;
cudaStream_t stream;
Expand All @@ -963,18 +1051,34 @@ class PrefilteredBruteForceOnBitsetTest
using PrefilteredBruteForceTestOnBitmap_float_int64 =
PrefilteredBruteForceOnBitmapTest<float, float, int64_t>;
TEST_P(PrefilteredBruteForceTestOnBitmap_float_int64, Result) { Run(); }
TEST_P(PrefilteredBruteForceTestOnBitmap_float_int64, ResultWithFilteringRateHint)
{
RunWithFilteringRateHint();
}

using PrefilteredBruteForceTestOnBitmap_half_int64 =
PrefilteredBruteForceOnBitmapTest<half, float, int64_t>;
TEST_P(PrefilteredBruteForceTestOnBitmap_half_int64, Result) { Run(); }
TEST_P(PrefilteredBruteForceTestOnBitmap_half_int64, ResultWithFilteringRateHint)
{
RunWithFilteringRateHint();
}

using PrefilteredBruteForceTestOnBitset_float_int64 =
PrefilteredBruteForceOnBitsetTest<float, float, int64_t>;
TEST_P(PrefilteredBruteForceTestOnBitset_float_int64, Result) { Run(); }
TEST_P(PrefilteredBruteForceTestOnBitset_float_int64, ResultWithFilteringRateHint)
{
RunWithFilteringRateHint();
}

using PrefilteredBruteForceTestOnBitset_half_int64 =
PrefilteredBruteForceOnBitsetTest<half, float, int64_t>;
TEST_P(PrefilteredBruteForceTestOnBitset_half_int64, Result) { Run(); }
TEST_P(PrefilteredBruteForceTestOnBitset_half_int64, ResultWithFilteringRateHint)
{
RunWithFilteringRateHint();
}

template <typename index_t>
const std::vector<PrefilteredBruteForceInputs<index_t>> selectk_inputs = {
Expand Down
Loading