From 7c7070a297f2f4a1f9419fe46358de671af5eea6 Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Tue, 26 May 2026 20:23:44 -0500 Subject: [PATCH 01/11] FEA first filter udf commit --- .../cuvs/detail/jit_lto/common_fragments.hpp | 1 + cpp/include/cuvs/neighbors/cagra.hpp | 5 +- cpp/include/cuvs/neighbors/common.hpp | 48 ++- cpp/src/neighbors/cagra.cuh | 21 ++ .../jit_lto_kernels/apply_filter_kernel.cu.in | 6 +- .../cagra/jit_lto_kernels/cagra_bitset.cuh | 63 +++- .../cagra_jit_launcher_factory.hpp | 75 +++-- .../jit_lto_kernels/cagra_planner_base.hpp | 13 +- ...mpute_distance_to_child_nodes_kernel.cu.in | 6 +- .../cagra/jit_lto_kernels/kernel_def.hpp | 10 +- .../jit_lto_kernels/sample_filter_udf.cuh | 91 ++++++ .../jit_lto_kernels/search_multi_cta_jit.cuh | 6 +- .../search_multi_cta_kernel.cu.in | 6 +- .../jit_lto_kernels/search_multi_jit.cuh | 50 +-- .../jit_lto_kernels/search_single_cta_jit.cuh | 14 +- .../search_single_cta_kernel.cu.in | 6 +- .../search_single_cta_p_kernel.cu.in | 6 +- .../detail/cagra/search_multi_cta_inst.cu.in | 3 + .../search_multi_cta_kernel_launcher_jit.cuh | 5 +- .../detail/cagra/search_multi_kernel.cuh | 7 +- .../search_multi_kernel_launcher_jit.cuh | 9 +- .../detail/cagra/search_single_cta_inst.cu.in | 3 + .../search_single_cta_kernel_launcher_jit.cuh | 68 +++- .../detail/cagra/shared_launcher_jit.hpp | 21 +- cpp/tests/CMakeLists.txt | 7 + .../neighbors/ann_cagra/test_filter_udf.cu | 299 ++++++++++++++++++ docs/source/filtering.rst | 139 ++++++++ docs/source/neighbors/cagra.rst | 4 + examples/cpp/CMakeLists.txt | 4 + examples/cpp/src/cagra_filter_udf_example.cu | 251 +++++++++++++++ 30 files changed, 1125 insertions(+), 122 deletions(-) create mode 100644 cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh create mode 100644 cpp/tests/neighbors/ann_cagra/test_filter_udf.cu create mode 100644 examples/cpp/src/cagra_filter_udf_example.cu diff --git a/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp 
b/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp index cb33e4109b..a82b6bcc1e 100644 --- a/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp +++ b/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp @@ -13,6 +13,7 @@ struct tag_i8 {}; struct tag_u8 {}; struct tag_filter_none {}; struct tag_filter_bitset {}; +struct tag_filter_udf {}; struct tag_bitset_u32 {}; diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index 6a5d15bc59..39599c6070 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -346,7 +346,10 @@ struct search_params : cuvs::neighbors::search_params { /** * A parameter indicating the rate of nodes to be filtered-out, when filtering is used. * The value must be equal to or greater than 0.0 and less than 1.0. Default value is - * negative, in which case the filtering rate is automatically calculated. + * negative, in which case the filtering rate is automatically calculated when possible. + * For `filtering::udf_filter`, CAGRA uses `udf_filter::filtering_rate` when this value is + * negative. If both values are negative, CAGRA assumes 0.0 because a UDF's selectivity cannot be + * inferred from the source string. */ float filtering_rate = -1.0; }; diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 7569ee2508..fae3afd58c 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -23,7 +23,9 @@ #include #include +#include #include +#include #ifdef __cpp_lib_bitops #include @@ -495,7 +497,7 @@ namespace filtering { * @{ */ -enum class FilterType { None, Bitmap, Bitset }; +enum class FilterType { None, Bitmap, Bitset, UDF }; struct base_filter { ~base_filter() = default; @@ -615,6 +617,50 @@ struct bitset_filter : public base_filter { void to_csr(raft::resources const& handle, csr_matrix_t& csr); }; +/** + * @brief JIT-LTO user-defined filter predicate. 
+ * + * The source must define a device function named by @c function_name with signature: + * + * @code{.cpp} + * __device__ bool cuvs_filter_udf(uint32_t query_id, source_index_t source_id, void* filter_data); + * @endcode + * + * Return @c true to allow a source vector to appear in the results and @c false to reject it. + * @c filter_data is passed through unchanged and must point to device-accessible memory when the + * UDF dereferences it. CAGRA currently provides @c source_index_t as @c uint32_t in the generated + * JIT fragment. + */ +struct udf_filter : public base_filter { + /** CUDA C++ source containing the device predicate. */ + std::string source; + /** Opaque device-accessible pointer passed to the predicate. */ + void* filter_data = nullptr; + /** Estimated fraction of rows rejected by the predicate, or negative if unknown. */ + float filtering_rate = -1.0f; + /** Optional stable cache key for equivalent generated source. */ + std::string cache_key; + /** Device function name to call from the generated CAGRA sample filter. 
*/ + std::string function_name = "cuvs_filter_udf"; + + udf_filter() = default; + + udf_filter(std::string source, + void* filter_data = nullptr, + float filtering_rate = -1.0f, + std::string cache_key = {}, + std::string function_name = "cuvs_filter_udf") + : source(std::move(source)), + filter_data(filter_data), + filtering_rate(filtering_rate), + cache_key(std::move(cache_key)), + function_name(std::move(function_name)) + { + } + + FilterType get_filter_type() const override { return FilterType::UDF; } +}; + /** @} */ // end group neighbors_filtering /** diff --git a/cpp/src/neighbors/cagra.cuh b/cpp/src/neighbors/cagra.cuh index 73c3794d39..ffe33ecd48 100644 --- a/cpp/src/neighbors/cagra.cuh +++ b/cpp/src/neighbors/cagra.cuh @@ -26,6 +26,8 @@ #include +#include + namespace cuvs::neighbors::cagra { // Member function implementations for cagra::index @@ -380,6 +382,25 @@ void search(raft::resources const& res, auto sample_filter_copy = sample_filter; return search_with_filtering( res, params_copy, idx, queries, neighbors, distances, sample_filter_copy); + } catch (const std::bad_cast&) { + } + + try { + auto& sample_filter = + dynamic_cast(sample_filter_ref); + search_params params_copy = params; + if (params.filtering_rate < 0.0) { + const float min_filtering_rate = 0.0f; + const float max_filtering_rate = 0.999f; + params_copy.filtering_rate = sample_filter.filtering_rate < 0.0f + ? 
0.0f + : std::min(std::max(sample_filter.filtering_rate, + min_filtering_rate), + max_filtering_rate); + } + auto sample_filter_copy = sample_filter; + return search_with_filtering( + res, params_copy, idx, queries, neighbors, distances, sample_filter_copy); } catch (const std::bad_cast&) { RAFT_FAIL("Unsupported sample filter type"); } diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in index 49d5d2fa07..c6e4c49df8 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in @@ -11,7 +11,7 @@ namespace { using index_t = @index_type@; using distance_t = @distance_type@; using source_index_t = @source_index_type@; -using cagra_bitset_t = cuvs::neighbors::cagra::detail::cagra_bitset; +using cagra_sample_filter_t = cuvs::neighbors::cagra::detail::cagra_sample_filter; } // namespace @@ -24,7 +24,7 @@ extern "C" __global__ void apply_filter_kernel(const source_index_t* const sourc const std::uint32_t result_buffer_size, const std::uint32_t num_queries, const std::uint32_t query_id_offset, - cagra_bitset_t bitset) + cagra_sample_filter_t filter_payload) { apply_filter_kernel_jit(source_indices_ptr, result_indices_ptr, @@ -33,7 +33,7 @@ extern "C" __global__ void apply_filter_kernel(const source_index_t* const sourc result_buffer_size, num_queries, query_id_offset, - bitset); + filter_payload); } static_assert(std::is_same_v @@ -15,10 +15,14 @@ namespace cuvs::neighbors::cagra::detail { template using cagra_bitset = cuvs::neighbors::detail::bitset_filter_data_t; -/// Host: bitset payload for kernels plus query offset for wrapped CAGRA filters. +enum class cagra_filter_kind : std::uint32_t { none = 0, bitset = 1, udf = 2 }; + +/// Host/device payload for linked CAGRA sample filters plus query offset for wrapped filters. 
template struct cagra_sample_filter { cagra_bitset bitset{}; + void* filter_data{nullptr}; + cagra_filter_kind filter_kind{cagra_filter_kind::none}; std::uint32_t query_id_offset{0}; }; @@ -29,6 +33,28 @@ template struct is_bitset_filter> : std::true_type {}; +template +struct is_udf_filter : std::false_type {}; + +template <> +struct is_udf_filter : std::true_type {}; + +template +void fill_cagra_sample_filter(cagra_sample_filter& out, const FilterT& filter) +{ + using DecayedFilter = std::decay_t; + if constexpr (is_bitset_filter::value) { + const auto bitset_view = filter.view(); + out.bitset.bitset_ptr = const_cast(bitset_view.data()); + out.bitset.bitset_len = static_cast(bitset_view.size()); + out.bitset.original_nbits = static_cast(bitset_view.get_original_nbits()); + out.filter_kind = cagra_filter_kind::bitset; + } else if constexpr (is_udf_filter::value) { + out.filter_data = filter.filter_data; + out.filter_kind = cagra_filter_kind::udf; + } +} + /// Host: fill @ref cagra_sample_filter from a CAGRA filter object (used by JIT LTO launchers). 
template cagra_sample_filter extract_cagra_sample_filter(const SampleFilterT& sample_filter) @@ -39,15 +65,34 @@ cagra_sample_filter extract_cagra_sample_filter(const SampleFilter sample_filter.offset; }) { out.query_id_offset = sample_filter.offset; - using InnerFilter = decltype(sample_filter.filter); - if constexpr (is_bitset_filter::value) { - const auto bitset_view = sample_filter.filter.view(); - out.bitset.bitset_ptr = const_cast(bitset_view.data()); - out.bitset.bitset_len = static_cast(bitset_view.size()); - out.bitset.original_nbits = static_cast(bitset_view.get_original_nbits()); - } + fill_cagra_sample_filter(out, sample_filter.filter); + } else { + fill_cagra_sample_filter(out, sample_filter); } return out; } +template +const cuvs::neighbors::filtering::udf_filter* get_cagra_udf_filter( + const SampleFilterT& sample_filter) +{ + using DecayedFilter = std::decay_t; + if constexpr (is_udf_filter::value) { + return &sample_filter; + } else if constexpr (requires { sample_filter.filter; }) { + return get_cagra_udf_filter(sample_filter.filter); + } else { + return nullptr; + } +} + +template +__device__ __forceinline__ void* get_cagra_sample_filter_data( + cagra_sample_filter& payload) +{ + if (payload.filter_kind == cagra_filter_kind::udf) { return payload.filter_data; } + if (payload.filter_kind == cagra_filter_kind::bitset) { return &payload.bitset; } + return nullptr; +} + } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp index 60d17c5128..997f916d78 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp @@ -10,6 +10,7 @@ #include "search_multi_cta_planner.hpp" #include "search_multi_kernel_planner.hpp" #include "search_single_cta_planner.hpp" +#include 
"sample_filter_udf.cuh" #include #include @@ -37,7 +38,8 @@ std::shared_ptr build_single_cta_launcher( const dataset_descriptor_host& dataset_desc, bool topk_by_bitonic_sort, bool bitonic_sort_and_merge_multi_warps, - bool persistent) + bool persistent, + std::unique_ptr sample_filter_udf_fragment) { single_cta_search::CagraSingleCtaSearchPlanner build_single_cta_launcher( } planner.add_search_kernel_fragment( topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps, persistent); - planner.add_sample_filter_device_function(); + planner.add_sample_filter_device_function(std::move(sample_filter_udf_fragment)); return planner.get_launcher(); } @@ -85,7 +87,8 @@ template std::shared_ptr build_multi_cta_launcher( - const dataset_descriptor_host& dataset_desc) + const dataset_descriptor_host& dataset_desc, + std::unique_ptr sample_filter_udf_fragment) { multi_cta_search::CagraMultiCtaSearchPlanner build_multi_cta_launcher( dataset_desc.metric, dataset_desc.team_size, dataset_desc.dataset_block_dim); } planner.add_search_multi_cta_kernel_fragment(); - planner.add_sample_filter_device_function(); + planner.add_sample_filter_device_function(std::move(sample_filter_udf_fragment)); return planner.get_launcher(); } @@ -130,7 +133,8 @@ template std::shared_ptr build_multi_kernel_launcher( const dataset_descriptor_host& dataset_desc, - const char* linked_kernel_name) + const char* linked_kernel_name, + std::unique_ptr sample_filter_udf_fragment) { multi_kernel_search::CagraMultiKernelSearchPlanner build_multi_kernel_launcher( planner.add_compute_distance_device_function( dataset_desc.metric, dataset_desc.team_size, dataset_desc.dataset_block_dim); } - planner.add_sample_filter_device_function(); + planner.add_sample_filter_device_function(std::move(sample_filter_udf_fragment)); planner.add_linked_kernel(linked_kernel_name); return planner.get_launcher(); } @@ -175,7 +179,8 @@ template -std::shared_ptr build_apply_filter_only_launcher() +std::shared_ptr 
build_apply_filter_only_launcher( + std::unique_ptr sample_filter_udf_fragment) { multi_kernel_search::CagraMultiKernelSearchPlanner build_apply_filter_only_launcher() CodebookTag, SampleFilterJitTag> planner("apply_filter_kernel"); - planner.add_sample_filter_device_function(); + planner.add_sample_filter_device_function(std::move(sample_filter_udf_fragment)); planner.add_linked_kernel("apply_filter_kernel"); return planner.get_launcher(); } @@ -204,7 +209,8 @@ std::shared_ptr make_cagra_single_cta_jit_launcher( const dataset_descriptor_host& dataset_desc, bool topk_by_bitonic_sort, bool bitonic_sort_and_merge_multi_warps, - bool persistent) + bool persistent, + std::unique_ptr sample_filter_udf_fragment = nullptr) { using DataTag = decltype(get_data_type_tag()); using IndexTag = decltype(get_index_type_tag()); @@ -225,7 +231,11 @@ std::shared_ptr make_cagra_single_cta_jit_launcher( IndexT, DistanceT, SourceIndexT>( - dataset_desc, topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps, persistent); + dataset_desc, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + persistent, + std::move(sample_filter_udf_fragment)); } using CodebookTag = codebook_tag_standard_t; if (dataset_desc.metric == cuvs::distance::DistanceType::BitwiseHamming) { @@ -242,7 +252,11 @@ std::shared_ptr make_cagra_single_cta_jit_launcher( IndexT, DistanceT, SourceIndexT>( - dataset_desc, topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps, persistent); + dataset_desc, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + persistent, + std::move(sample_filter_udf_fragment)); } using QueryTag = query_type_tag_standard_t; return cagra_jit_launcher_factory_detail::build_single_cta_launcher make_cagra_single_cta_jit_launcher( IndexT, DistanceT, SourceIndexT>( - dataset_desc, topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps, persistent); + dataset_desc, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + persistent, + 
std::move(sample_filter_udf_fragment)); } /// Build a JIT AlgorithmLauncher for multi-CTA CAGRA search. @@ -266,7 +284,8 @@ template std::shared_ptr make_cagra_multi_cta_jit_launcher( - const dataset_descriptor_host& dataset_desc) + const dataset_descriptor_host& dataset_desc, + std::unique_ptr sample_filter_udf_fragment = nullptr) { using DataTag = decltype(get_data_type_tag()); using IndexTag = decltype(get_index_type_tag()); @@ -286,7 +305,8 @@ std::shared_ptr make_cagra_multi_cta_jit_launcher( DataT, IndexT, DistanceT, - SourceIndexT>(dataset_desc); + SourceIndexT>( + dataset_desc, std::move(sample_filter_udf_fragment)); } using CodebookTag = codebook_tag_standard_t; if (dataset_desc.metric == cuvs::distance::DistanceType::BitwiseHamming) { @@ -302,7 +322,8 @@ std::shared_ptr make_cagra_multi_cta_jit_launcher( DataT, IndexT, DistanceT, - SourceIndexT>(dataset_desc); + SourceIndexT>( + dataset_desc, std::move(sample_filter_udf_fragment)); } using QueryTag = query_type_tag_standard_t; return cagra_jit_launcher_factory_detail::build_multi_cta_launcher make_cagra_multi_cta_jit_launcher( DataT, IndexT, DistanceT, - SourceIndexT>(dataset_desc); + SourceIndexT>( + dataset_desc, std::move(sample_filter_udf_fragment)); } /// Build a JIT AlgorithmLauncher for multi-kernel CAGRA helpers that need `setup_workspace` and @@ -331,7 +353,8 @@ template std::shared_ptr make_cagra_multi_kernel_jit_launcher( const dataset_descriptor_host& dataset_desc, - const char* linked_kernel_name) + const char* linked_kernel_name, + std::unique_ptr sample_filter_udf_fragment = nullptr) { using DataTag = decltype(get_data_type_tag()); using IndexTag = decltype(get_index_type_tag()); @@ -352,7 +375,7 @@ std::shared_ptr make_cagra_multi_kernel_jit_launcher( IndexT, DistanceT, SourceIndexT>( - dataset_desc, linked_kernel_name); + dataset_desc, linked_kernel_name, std::move(sample_filter_udf_fragment)); } using CodebookTag = codebook_tag_standard_t; if (dataset_desc.metric == 
cuvs::distance::DistanceType::BitwiseHamming) { @@ -369,7 +392,7 @@ std::shared_ptr make_cagra_multi_kernel_jit_launcher( IndexT, DistanceT, SourceIndexT>( - dataset_desc, linked_kernel_name); + dataset_desc, linked_kernel_name, std::move(sample_filter_udf_fragment)); } using QueryTag = query_type_tag_standard_t; return cagra_jit_launcher_factory_detail::build_multi_kernel_launcher make_cagra_multi_kernel_jit_launcher( IndexT, DistanceT, SourceIndexT>( - dataset_desc, linked_kernel_name); + dataset_desc, linked_kernel_name, std::move(sample_filter_udf_fragment)); } /// JIT launcher for the post-search `apply_filter_kernel` only (no workspace / distance fragments). @@ -396,7 +419,8 @@ template std::shared_ptr make_cagra_apply_filter_jit_launcher( - const dataset_descriptor_host& dataset_desc) + const dataset_descriptor_host& dataset_desc, + std::unique_ptr sample_filter_udf_fragment = nullptr) { using DataTag = decltype(get_data_type_tag()); using IndexTag = decltype(get_index_type_tag()); @@ -416,7 +440,8 @@ std::shared_ptr make_cagra_apply_filter_jit_launcher( DataT, IndexT, DistanceT, - SourceIndexT>(); + SourceIndexT>( + std::move(sample_filter_udf_fragment)); } using CodebookTag = codebook_tag_standard_t; if (dataset_desc.metric == cuvs::distance::DistanceType::BitwiseHamming) { @@ -432,7 +457,8 @@ std::shared_ptr make_cagra_apply_filter_jit_launcher( DataT, IndexT, DistanceT, - SourceIndexT>(); + SourceIndexT>( + std::move(sample_filter_udf_fragment)); } using QueryTag = query_type_tag_standard_t; return cagra_jit_launcher_factory_detail::build_apply_filter_only_launcher make_cagra_apply_filter_jit_launcher( DataT, IndexT, DistanceT, - SourceIndexT>(); + SourceIndexT>( + std::move(sample_filter_udf_fragment)); } } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp index 317ca1a1b6..0c3ed64d13 100644 --- 
a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp @@ -9,9 +9,11 @@ #include #include #include +#include #include #include +#include #include #include @@ -250,9 +252,16 @@ struct CagraPlannerBase : AlgorithmPlanner { static_cast(dataset_block_dim)); } - void add_sample_filter_device_function() + void add_sample_filter_device_function(std::unique_ptr udf_fragment = nullptr) { - if constexpr (!std::is_same_v) { + if constexpr (std::is_same_v) { + RAFT_EXPECTS(udf_fragment == nullptr, "Unexpected CAGRA sample-filter UDF fragment"); + } else if constexpr (std::is_same_v) { + RAFT_EXPECTS(udf_fragment != nullptr, "CAGRA UDF filter requires a JIT-LTO fragment"); + this->add_fragment(std::move(udf_fragment)); + } else { + RAFT_EXPECTS(udf_fragment == nullptr, "Built-in CAGRA sample filters use static fragments"); this->add_static_fragment>(); diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in index 0035ae1da0..d0cc88bc74 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in @@ -14,7 +14,7 @@ using distance_t = @distance_type@; using source_index_t = @source_index_type@; using dataset_desc_base = cuvs::neighbors::cagra::detail::dataset_descriptor_base_t; -using cagra_bitset_t = cuvs::neighbors::cagra::detail::cagra_bitset; +using cagra_sample_filter_t = cuvs::neighbors::cagra::detail::cagra_sample_filter; } // namespace @@ -35,7 +35,7 @@ extern "C" __global__ void compute_distance_to_child_nodes(const index_t* const index_t* const result_indices_ptr, distance_t* const result_distances_ptr, const std::uint32_t ldd, - cagra_bitset_t bitset) + cagra_sample_filter_t filter_payload) { 
compute_distance_to_child_nodes_kernel_jit( parent_node_list, @@ -53,7 +53,7 @@ extern "C" __global__ void compute_distance_to_child_nodes(const index_t* const result_indices_ptr, result_distances_ptr, ldd, - bitset); + filter_payload); } static_assert( diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp index 370dbd33d8..58c32707fe 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp @@ -47,7 +47,7 @@ using search_single_cta_kernel_func_t = const std::uint32_t, const dataset_descriptor_base_t*, const IndexT, - cagra_bitset); + cagra_sample_filter); namespace single_cta_search { @@ -76,7 +76,7 @@ using search_single_cta_p_kernel_func_t = const std::uint32_t, const std::uint32_t, const dataset_descriptor_base_t*, - cagra_bitset); + cagra_sample_filter); } // namespace single_cta_search @@ -105,7 +105,7 @@ using search_multi_cta_kernel_func_t = std::uint32_t* const, const IndexT, const std::uint32_t, - cagra_bitset); + cagra_sample_filter); } // namespace multi_cta_search @@ -143,7 +143,7 @@ using compute_distance_to_child_nodes_kernel_func_t = IndexT* const, DistanceT* const, const std::uint32_t, - cagra_bitset); + cagra_sample_filter); template using apply_filter_kernel_func_t = void(const SourceIndexT* const, @@ -153,7 +153,7 @@ using apply_filter_kernel_func_t = void(const SourceIndexT* const, const std::uint32_t, const std::uint32_t, const std::uint32_t, - cagra_bitset); + cagra_sample_filter); } // namespace multi_kernel_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh new file mode 100644 index 0000000000..a09e1cda35 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh @@ -0,0 +1,91 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, 
NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "cagra_bitset.cuh" + +#include +#include + +#include +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail { + +template +inline constexpr const char* cagra_udf_source_index_type_name() +{ + static_assert(std::is_same_v, + "CAGRA filter UDFs currently support SourceIndexT = uint32_t only"); + return "uint32_t"; +} + +inline std::string instantiate_cagra_sample_filter_udf(std::string const& user_source, + std::string const& function_name, + const char* source_index_type) +{ + std::ostringstream oss; + oss << R"( +using int8_t = signed char; +using uint8_t = unsigned char; +using int32_t = int; +using uint32_t = unsigned int; +using int64_t = long long; +using uint64_t = unsigned long long; +using source_index_t = )" + << source_index_type << R"(; + +namespace cuvs::neighbors::detail { +template +extern __device__ bool sample_filter(uint32_t query_id, SourceIndexT node_id, void* filter_data); +} // namespace cuvs::neighbors::detail + +)"; + oss << user_source << R"( + +namespace cuvs::neighbors::detail { + +template <> +__device__ bool sample_filter(uint32_t query_id, + source_index_t node_id, + void* filter_data) +{ + return )" + << function_name << R"((query_id, node_id, filter_data); +} + +} // namespace cuvs::neighbors::detail +)"; + return oss.str(); +} + +template +std::unique_ptr make_cagra_sample_filter_udf_fragment( + const SampleFilterT& sample_filter) +{ + const auto* udf = get_cagra_udf_filter(sample_filter); + if (udf == nullptr) { return nullptr; } + + RAFT_EXPECTS(!udf->source.empty(), "CAGRA filter UDF source must not be empty"); + RAFT_EXPECTS(!udf->function_name.empty(), "CAGRA filter UDF function name must not be empty"); + + auto code = instantiate_cagra_sample_filter_udf( + udf->source, udf->function_name, cagra_udf_source_index_type_name()); + std::string key = "cagra_sample_filter_udf:"; + 
key += cagra_udf_source_index_type_name(); + key += ":"; + key += udf->cache_key.empty() ? udf->source : udf->cache_key; + key += ":"; + key += udf->function_name; + key += ":"; + key += code; + return nvrtc_compiler().compile(key, code); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh index 492195f359..446b15cc3e 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh @@ -52,7 +52,7 @@ __device__ void search_kernel_jit( uint32_t* const num_executed_iterations, /* stats */ const IndexT graph_size, const uint32_t query_id_offset, // Offset to add to query_id when calling filter - cagra_bitset bitset) + cagra_sample_filter filter_payload) { using DATA_T = DataT; using INDEX_T = IndexT; @@ -261,7 +261,7 @@ __device__ void search_kernel_jit( const auto parent_id = result_indices_buffer[parent_indices_buffer[p]] & ~index_msb_1_mask; if (!sample_filter(query_id + query_id_offset, to_source_index(parent_id), - bitset.bitset_ptr != nullptr ? &bitset : nullptr)) { + get_cagra_sample_filter_data(filter_payload))) { // If the parent must not be in the resulting top-k list, remove from the parent list result_distances_buffer[parent_indices_buffer[p]] = utils::get_max_value(); result_indices_buffer[parent_indices_buffer[p]] = invalid_index; @@ -280,7 +280,7 @@ __device__ void search_kernel_jit( index &= ~index_msb_1_mask; if (!sample_filter(query_id + query_id_offset, to_source_index(index), - bitset.bitset_ptr != nullptr ? 
&bitset : nullptr)) { + get_cagra_sample_filter_data(filter_payload))) { result_indices_buffer[i] = invalid_index; result_distances_buffer[i] = utils::get_max_value(); } diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in index 9acd73687b..1d22bd5449 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in @@ -14,7 +14,7 @@ using distance_t = @distance_type@; using source_index_t = @source_index_type@; using dataset_desc_base = cuvs::neighbors::cagra::detail::dataset_descriptor_base_t; -using cagra_bitset_t = cuvs::neighbors::cagra::detail::cagra_bitset; +using cagra_sample_filter_t = cuvs::neighbors::cagra::detail::cagra_sample_filter; } // namespace @@ -42,7 +42,7 @@ extern "C" __global__ __launch_bounds__(1024, 1) void search_multi_cta( std::uint32_t* const num_executed_iterations, const index_t graph_size, const std::uint32_t query_id_offset, - cagra_bitset_t bitset) + cagra_sample_filter_t filter_payload) { search_kernel_jit(result_indices_ptr, result_distances_ptr, @@ -65,7 +65,7 @@ extern "C" __global__ __launch_bounds__(1024, 1) void search_multi_cta( num_executed_iterations, graph_size, query_id_offset, - bitset); + filter_payload); } static_assert( diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh index a7ab078343..b580db0b80 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh @@ -105,7 +105,7 @@ __device__ void compute_distance_to_child_nodes_kernel_jit( IndexT* const result_indices_ptr, // [num_queries, ldd] DistanceT* const result_distances_ptr, // [num_queries, ldd] const std::uint32_t ldd, // (*) ldd >= search_width * 
graph_degree - cagra_bitset bitset) + cagra_sample_filter filter_payload) { using INDEX_T = IndexT; using DISTANCE_T = DistanceT; @@ -115,13 +115,14 @@ __device__ void compute_distance_to_child_nodes_kernel_jit( const auto team_size = 1u << team_size_bits; const uint32_t ldb = hashmap::get_size(hash_bitlen); - const auto tid = threadIdx.x + blockDim.x * blockIdx.x; - const auto global_team_id = tid >> team_size_bits; - const auto query_id = blockIdx.y; + const auto tid = threadIdx.x + blockDim.x * blockIdx.x; + const auto global_team_id = tid >> team_size_bits; + const auto local_query_id = blockIdx.y; + const auto filter_query_id = filter_payload.query_id_offset + local_query_id; extern __shared__ uint8_t smem[]; auto smem_desc = - setup_workspace(dataset_desc, smem, query_ptr, query_id); + setup_workspace(dataset_desc, smem, query_ptr, local_query_id); __syncthreads(); if (global_team_id >= search_width * graph_degree) { return; } @@ -132,10 +133,12 @@ __device__ void compute_distance_to_child_nodes_kernel_jit( if (parent_list_index == utils::get_max_value()) { return; } constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const auto raw_parent_index = parent_candidates_ptr[parent_list_index + (lds * query_id)]; + const auto raw_parent_index = + parent_candidates_ptr[parent_list_index + (lds * local_query_id)]; if (raw_parent_index == utils::get_max_value()) { - result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); + result_distances_ptr[ldd * local_query_id + global_team_id] = + utils::get_max_value(); return; } const auto parent_index = raw_parent_index & ~index_msb_1_mask; @@ -153,24 +156,25 @@ __device__ void compute_distance_to_child_nodes_kernel_jit( if (compute_distance_flag) { if ((threadIdx.x & (team_size - 1)) == 0) { - result_indices_ptr[ldd * blockIdx.y + global_team_id] = child_id; - result_distances_ptr[ldd * blockIdx.y + global_team_id] = norm2; + result_indices_ptr[ldd * local_query_id + 
global_team_id] = child_id; + result_distances_ptr[ldd * local_query_id + global_team_id] = norm2; } } else { if ((threadIdx.x & (team_size - 1)) == 0) { - result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); + result_distances_ptr[ldd * local_query_id + global_team_id] = + utils::get_max_value(); } } - if (bitset.bitset_ptr != nullptr) { - const SourceIndexT node_id = source_indices_ptr == nullptr - ? static_cast(parent_index) - : static_cast(source_indices_ptr[parent_index]); - if (!sample_filter(query_id, node_id, &bitset)) { - parent_candidates_ptr[parent_list_index + (lds * query_id)] = utils::get_max_value(); - parent_distance_ptr[parent_list_index + (lds * query_id)] = - utils::get_max_value(); - } + const SourceIndexT node_id = source_indices_ptr == nullptr + ? static_cast(parent_index) + : static_cast(source_indices_ptr[parent_index]); + if (!sample_filter( + filter_query_id, node_id, get_cagra_sample_filter_data(filter_payload))) { + parent_candidates_ptr[parent_list_index + (lds * local_query_id)] = + utils::get_max_value(); + parent_distance_ptr[parent_list_index + (lds * local_query_id)] = + utils::get_max_value(); } } @@ -185,7 +189,7 @@ __device__ void apply_filter_kernel_jit( const std::uint32_t result_buffer_size, const std::uint32_t num_queries, const std::uint32_t query_id_offset, - cagra_bitset bitset) + cagra_sample_filter filter_payload) { constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value; const auto tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -195,14 +199,14 @@ __device__ void apply_filter_kernel_jit( const auto index = i + j * lds; if (result_indices_ptr[index] != ~index_msb_1_mask) { - // Use extern sample_filter function with 3 params: query_id, node_id, filter_data - // Third argument is &bitset (layout matches bitset_filter_data_t) or nullptr for none_filter + // Use extern sample_filter function with 3 params: query_id, node_id, filter_data. 
+ // The payload maps built-in bitset filters and UDF context pointers to the linked ABI. SourceIndexT node_id = source_indices_ptr == nullptr ? static_cast(result_indices_ptr[index]) : source_indices_ptr[result_indices_ptr[index]]; if (!sample_filter( - query_id_offset + j, node_id, bitset.bitset_ptr != nullptr ? &bitset : nullptr)) { + query_id_offset + j, node_id, get_cagra_sample_filter_data(filter_payload))) { result_indices_ptr[index] = utils::get_max_value(); result_distances_ptr[index] = utils::get_max_value(); } diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh index 282490559a..69098c75e5 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh @@ -78,7 +78,7 @@ RAFT_DEVICE_INLINE_FUNCTION void search_core( const std::uint32_t query_id, const std::uint32_t query_id_offset, // Offset to add to query_id when calling filter const dataset_descriptor_base_t* dataset_desc, - cagra_bitset bitset, + cagra_sample_filter filter_payload, const IndexT graph_size = 0) // Original number of bits { using LOAD_T = device::LOAD_128BIT_T; @@ -306,7 +306,7 @@ RAFT_DEVICE_INLINE_FUNCTION void search_core( const auto parent_id = result_indices_buffer[parent_list_buffer[p]] & ~index_msb_1_mask; if (!sample_filter(query_id + query_id_offset, to_source_index(parent_id), - bitset.bitset_ptr != nullptr ? &bitset : nullptr)) { + get_cagra_sample_filter_data(filter_payload))) { result_distances_buffer[parent_list_buffer[p]] = utils::get_max_value(); result_indices_buffer[parent_list_buffer[p]] = invalid_index; *filter_flag = 1; @@ -327,7 +327,7 @@ RAFT_DEVICE_INLINE_FUNCTION void search_core( if (node_id != (invalid_index & ~index_msb_1_mask) && !sample_filter(query_id + query_id_offset, to_source_index(node_id), - bitset.bitset_ptr != nullptr ? 
&bitset : nullptr)) { + get_cagra_sample_filter_data(filter_payload))) { result_distances_buffer[i] = utils::get_max_value(); result_indices_buffer[i] = invalid_index; } @@ -483,7 +483,7 @@ __device__ void search_kernel_jit( const std::uint32_t query_id_offset, // Offset to add to query_id when calling filter const dataset_descriptor_base_t* dataset_desc, const IndexT graph_size, - cagra_bitset bitset) + cagra_sample_filter filter_payload) { const auto query_id = blockIdx.y; search_core* dataset_desc, - cagra_bitset bitset) + cagra_sample_filter filter_payload) { using job_desc_type = job_desc_t>; __shared__ typename job_desc_type::input_t job_descriptor; @@ -625,7 +625,7 @@ __device__ void search_single_cta_p_impl( query_id, query_id_offset, dataset_desc, - bitset); + filter_payload); // make sure all writes are visible even for the host // (e.g. when result buffers are in pinned memory) diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in index 26b363a90a..40e7b15a6f 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in @@ -17,7 +17,7 @@ using distance_t = @distance_type@; using source_index_t = @source_index_type@; using dataset_desc_base = cuvs::neighbors::cagra::detail::dataset_descriptor_base_t; -using cagra_bitset_t = cuvs::neighbors::cagra::detail::cagra_bitset; +using cagra_sample_filter_t = cuvs::neighbors::cagra::detail::cagra_sample_filter; } // namespace @@ -49,7 +49,7 @@ extern "C" __global__ __launch_bounds__(1024, 1) void search_single_cta( const std::uint32_t query_id_offset, const dataset_desc_base* dataset_desc, const index_t graph_size, - cagra_bitset_t bitset) + cagra_sample_filter_t filter_payload) { single_cta_search::search_kernel_jit; using job_descriptor_batch = scta_jit::job_desc_t>; -using 
cagra_bitset_t = cuvs::neighbors::cagra::detail::cagra_bitset; +using cagra_sample_filter_t = cuvs::neighbors::cagra::detail::cagra_sample_filter; } // namespace @@ -51,7 +51,7 @@ extern "C" __global__ __launch_bounds__(1024, 1) void search_single_cta_p( const std::uint32_t small_hash_reset_interval, const std::uint32_t query_id_offset, const dataset_desc_base* dataset_desc, - cagra_bitset_t bitset) + cagra_sample_filter_t filter_payload) { search_single_cta_p_impl>; +using udf_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset< + cuvs::neighbors::filtering::udf_filter>; } // namespace @@ -20,5 +22,6 @@ instantiate_kernel_selection(data_t, float, cuvs::neighbors::filtering::none_sample_filter); instantiate_kernel_selection(data_t, uint32_t, float, bitset_filter_t); +instantiate_kernel_selection(data_t, uint32_t, float, udf_filter_t); } // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh index 663f8a4559..82bdd8ac93 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh @@ -64,7 +64,8 @@ void select_and_run(const dataset_descriptor_host& dat IndexT, DistanceT, SourceIndexT, - sample_filter_jit_tag_t>(dataset_desc); + sample_filter_jit_tag_t>( + dataset_desc, make_cagra_sample_filter_udf_fragment(sample_filter)); if (!launcher) { RAFT_FAIL("Failed to get JIT launcher"); } @@ -142,7 +143,7 @@ void select_and_run(const dataset_descriptor_host& dat num_executed_iterations, static_cast(graph.extent(0)), query_id_offset, - bf.bitset); + bf); }; cuvs::neighbors::detail::safely_launch_kernel_with_smem_size< multi_cta_search::search_multi_cta_kernel_func_t>( diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh 
b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh index f6dfcb8aee..5016c4b962 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh @@ -519,7 +519,9 @@ struct search DISTANCE_T, SourceIndexT, sample_filter_jit_tag_t>( - dataset_desc, "compute_distance_to_child_nodes"); + dataset_desc, + "compute_distance_to_child_nodes", + make_cagra_sample_filter_udf_fragment(sample_filter)); unsigned iter = 0; while (1) { @@ -615,7 +617,8 @@ struct search DISTANCE_T, SourceIndexT, sample_filter_jit_tag_t>( - dataset_desc); + dataset_desc, + make_cagra_sample_filter_udf_fragment(sample_filter)); apply_filter_jit( source_indices_ptr, diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh index d0e3db2c99..4d7b087b80 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh @@ -142,7 +142,7 @@ void compute_distance_to_child_nodes_jit( result_indices_ptr, result_distances_ptr, ldd, - bf.bitset); + bf); RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -160,9 +160,8 @@ void apply_filter_jit(const SourceIndexT* source_indices_ptr, cudaStream_t cuda_stream, std::shared_ptr const& launcher) { - // Note: query_id for the linked filter is the function's `query_id_offset` + query index, not - // the wrapper's offset; we only need bitset pointers (same as other JIT launchers). 
const auto bf = extract_cagra_sample_filter(sample_filter); + const auto effective_query_id_offset = query_id_offset + bf.query_id_offset; const std::uint32_t block_size = 256; const std::uint32_t grid_size = raft::ceildiv(num_queries * result_buffer_size, block_size); @@ -181,8 +180,8 @@ void apply_filter_jit(const SourceIndexT* source_indices_ptr, lds, result_buffer_size, num_queries, - query_id_offset, - bf.bitset); + effective_query_id_offset, + bf); RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in index 85342e7093..4616a9652b 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in @@ -11,6 +11,8 @@ namespace { using data_t = @data_type@; using bitset_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset< cuvs::neighbors::filtering::bitset_filter>; +using udf_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset< + cuvs::neighbors::filtering::udf_filter>; } // namespace @@ -20,5 +22,6 @@ instantiate_kernel_selection(data_t, float, cuvs::neighbors::filtering::none_sample_filter); instantiate_kernel_selection(data_t, uint32_t, float, bitset_filter_t); +instantiate_kernel_selection(data_t, uint32_t, float, udf_filter_t); } // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh index 96db9a743b..6a3118e3ae 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh @@ -45,6 +45,7 @@ #include #include #include +#include #include #include #include @@ -57,6 +58,40 @@ namespace cuvs::neighbors::cagra::detail::single_cta_search { +inline 
std::uint64_t cagra_hash_combine(std::uint64_t seed, std::uint64_t value) +{ + return seed ^ (value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2)); +} + +template +std::uint64_t cagra_udf_source_hash(const SampleFilterT& sample_filter) +{ + if (const auto* udf = get_cagra_udf_filter(sample_filter); udf != nullptr) { + std::uint64_t seed = 0; + seed = cagra_hash_combine(seed, std::hash{}(udf->source)); + seed = cagra_hash_combine(seed, std::hash{}(udf->cache_key)); + seed = cagra_hash_combine(seed, std::hash{}(udf->function_name)); + return seed; + } + return 0; +} + +template +std::uint64_t cagra_sample_filter_hash(const SampleFilterT& sample_filter) +{ + const auto payload = extract_cagra_sample_filter(sample_filter); + std::uint64_t seed = static_cast(payload.filter_kind); + seed = cagra_hash_combine( + seed, static_cast(reinterpret_cast(payload.bitset.bitset_ptr))); + seed = cagra_hash_combine(seed, static_cast(payload.bitset.bitset_len)); + seed = cagra_hash_combine(seed, static_cast(payload.bitset.original_nbits)); + seed = cagra_hash_combine( + seed, static_cast(reinterpret_cast(payload.filter_data))); + seed = cagra_hash_combine(seed, static_cast(payload.query_id_offset)); + seed = cagra_hash_combine(seed, cagra_udf_source_hash(sample_filter)); + return seed; +} + // Persistent queues / runner (host). worker_handle_t, job_desc_t, kCacheLineBytes, k* job limits: // `jit_lto_kernels/search_single_cta_device_helpers.cuh` via `kernel_def.hpp`. 
@@ -453,7 +488,7 @@ struct alignas(kCacheLineBytes) persistent_runner_jit_t : public persistent_runn rmm::device_uvector hashmap; std::atomic> last_touch; uint64_t param_hash; - cagra_bitset bitset; + cagra_sample_filter filter_payload; static inline auto calculate_parameter_hash( std::reference_wrapper> dataset_desc, @@ -481,7 +516,7 @@ struct alignas(kCacheLineBytes) persistent_runner_jit_t : public persistent_runn bool bitonic_sort_and_merge_multi_warps) -> uint64_t { (void)small_hash_bitlen; - (void)sample_filter; + const uint64_t filter_key = cagra_sample_filter_hash(sample_filter); const uint64_t bitonic_key = (topk_by_bitonic_sort ? 1ULL : 0ULL) ^ (bitonic_sort_and_merge_multi_warps ? 2ULL : 0ULL); return uint64_t(graph.data_handle()) ^ uint64_t(source_indices_ptr) ^ @@ -489,13 +524,14 @@ struct alignas(kCacheLineBytes) persistent_runner_jit_t : public persistent_runn hash_bitlen ^ small_hash_reset_interval ^ num_random_samplings ^ rand_xor_mask ^ num_seeds ^ itopk_size ^ search_width ^ min_iterations ^ max_iterations ^ uint64_t(persistent_lifetime * 1000) ^ uint64_t(persistent_device_usage * 1000) ^ - bitonic_key; + bitonic_key ^ filter_key; } static auto make_persistent_launcher( const dataset_descriptor_host& dataset_desc, bool topk_by_bitonic_sort, - bool bitonic_sort_and_merge_multi_warps) -> std::shared_ptr + bool bitonic_sort_and_merge_multi_warps, + SampleFilterT sample_filter) -> std::shared_ptr { auto launcher = make_cagra_single_cta_jit_launcher(sample_filter)); if (!launcher) { RAFT_FAIL("Failed to get JIT launcher for CAGRA persistent search kernel"); } return launcher; } @@ -535,8 +572,10 @@ struct alignas(kCacheLineBytes) persistent_runner_jit_t : public persistent_runn bool topk_by_bitonic_sort, bool bitonic_sort_and_merge_multi_warps) : persistent_runner_base_t{persistent_lifetime}, - launcher{make_persistent_launcher( - dataset_desc.get(), topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps)}, + 
launcher{make_persistent_launcher(dataset_desc.get(), + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + sample_filter)}, block_size{block_size}, worker_handles(0, stream, worker_handles_mr), job_descriptors(kMaxJobsNum, stream, job_descriptor_mr), @@ -568,7 +607,7 @@ struct alignas(kCacheLineBytes) persistent_runner_jit_t : public persistent_runn bitonic_sort_and_merge_multi_warps)) { const auto bf = extract_cagra_sample_filter(sample_filter); - this->bitset = bf.bitset; + this->filter_payload = bf; const uint32_t query_id_offset = bf.query_id_offset; // set kernel launch parameters @@ -656,7 +695,7 @@ struct alignas(kCacheLineBytes) persistent_runner_jit_t : public persistent_runn small_hash_reset_interval_u32, // Cast size_t to uint32_t query_id_offset, // Offset to add to query_id when calling filter dev_desc, - bitset); + filter_payload); last_touch.store(std::chrono::system_clock::now(), std::memory_order_relaxed); } @@ -757,9 +796,9 @@ void select_and_run( const SourceIndexT* source_indices_ptr = source_indices.has_value() ? 
source_indices->data_handle() : nullptr; - const auto bf = extract_cagra_sample_filter(sample_filter); - const cagra_bitset bitset = bf.bitset; - const uint32_t query_id_offset = bf.query_id_offset; + const auto bf = extract_cagra_sample_filter(sample_filter); + const auto filter_payload = bf; + const uint32_t query_id_offset = bf.query_id_offset; // Use common logic to compute launch config auto config = compute_launch_config(num_itopk_candidates, ps.itopk_size, block_size); @@ -813,7 +852,8 @@ void select_and_run( dataset_desc, topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps, - false /* persistent */); + false /* persistent */, + make_cagra_sample_filter_udf_fragment(sample_filter)); if (!launcher) { RAFT_FAIL("Failed to get JIT launcher for CAGRA search kernel"); } // Get the device descriptor pointer - dev_ptr() initializes it if needed @@ -871,7 +911,7 @@ void select_and_run( query_id_offset, // Offset to add to query_id when calling filter dev_desc, static_cast(graph.extent(0)), - bitset); + filter_payload); }; cuvs::neighbors::detail::safely_launch_kernel_with_smem_size< diff --git a/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp index a16c810409..d6de21644f 100644 --- a/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp +++ b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp @@ -100,19 +100,22 @@ struct sample_filter_jit_tag { using namespace cuvs::neighbors::filtering; if constexpr (std::is_same_v) { return cuvs::neighbors::detail::tag_filter_none{}; + } else if constexpr (is_udf_filter::value) { + return cuvs::neighbors::detail::tag_filter_udf{}; } else if constexpr (requires { std::declval().filter; }) { using InnerFilter = decltype(std::declval().filter); - if constexpr (is_bitset_filter::value || - std::is_same_v> || - std::is_same_v>) { + if constexpr (is_bitset_filter>::value || + std::is_same_v, bitset_filter> || + std::is_same_v, bitset_filter>) { return 
cuvs::neighbors::detail::tag_filter_bitset{}; + } else if constexpr (is_udf_filter>::value) { + return cuvs::neighbors::detail::tag_filter_udf{}; } else { static_assert( cagra_jit_sample_filter_tag_type_always_false, "CAGRA JIT: sample_filter_jit_tag does not know how to link this filter. " - "CagraSampleFilterWithQueryIdOffset requires Inner of type " - "bitset_filter (see cagra_bitset.cuh is_bitset_filter and sample_filter_utils.cuh). " + "CagraSampleFilterWithQueryIdOffset requires Inner to be a supported " + "built-in filter or udf_filter (see cagra_bitset.cuh and sample_filter_utils.cuh). " "For a new filter kind, add a sample_filter_jit_tag branch. " "(SAMPLE_FILTER_T in error; check InnerFilter in compiler output.)"); } @@ -120,9 +123,9 @@ struct sample_filter_jit_tag { static_assert( cagra_jit_sample_filter_tag_type_always_false, "CAGRA JIT: sample_filter_jit_tag: SAMPLE_FILTER_T must be cuvs::neighbors::filtering::" - "none_sample_filter, or " - "cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset<" - "bitset_filter>. Unknown wrapper type. " + "none_sample_filter, udf_filter, or " + "cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset. " + "Unknown wrapper type. 
" "(SAMPLE_FILTER_T in error; add a branch in sample_filter_jit_tag.)"); } } diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 48a444fa18..91c4f62cbb 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -180,6 +180,13 @@ ConfigureTest( PERCENT 100 ) +ConfigureTest( + NAME NEIGHBORS_ANN_CAGRA_FILTER_UDF_TEST + PATH neighbors/ann_cagra/test_filter_udf.cu + GPUS 1 + PERCENT 100 +) + ConfigureTest( NAME NEIGHBORS_ANN_CAGRA_HELPERS_TEST PATH neighbors/ann_cagra/test_optimize_uint32_t.cu neighbors/ann_cagra/test_batch_load_iterator.cu diff --git a/cpp/tests/neighbors/ann_cagra/test_filter_udf.cu b/cpp/tests/neighbors/ann_cagra/test_filter_udf.cu new file mode 100644 index 0000000000..10da68dbdd --- /dev/null +++ b/cpp/tests/neighbors/ann_cagra/test_filter_udf.cu @@ -0,0 +1,299 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra { +namespace { + +constexpr int64_t n_rows = 768; +constexpr int64_t n_dim = 16; +constexpr int64_t n_queries = 6; +constexpr int64_t k = 8; +constexpr int64_t threshold = 192; +constexpr int64_t high_filtering_threshold = 704; +constexpr float high_filtering_rate = + static_cast(high_filtering_threshold) / static_cast(n_rows); + +struct tenant_filter_context { + const uint32_t* row_tenants; + const uint32_t* query_tenants; +}; + +struct cagra_search_result { + std::vector neighbors; + std::vector distances; +}; + +std::string accept_all_udf_source() +{ + return R"cpp( + __device__ bool cuvs_filter_udf(uint32_t, source_index_t, void*) { return true; } + )cpp"; +} + +std::string threshold_udf_source() +{ + return R"cpp( + __device__ bool cuvs_filter_udf(uint32_t, source_index_t source_id, void*) + { + return source_id >= 192; + } + )cpp"; +} + +std::string 
reject_all_udf_source() +{ + return R"cpp( + __device__ bool cuvs_filter_udf(uint32_t, source_index_t, void*) { return false; } + )cpp"; +} + +std::string high_filtering_rate_udf_source() +{ + return R"cpp( + __device__ bool cuvs_filter_udf(uint32_t, source_index_t source_id, void*) + { + return source_id >= 704; + } + )cpp"; +} + +std::string tenant_udf_source() +{ + return R"cpp( + struct tenant_filter_context { + const uint32_t* row_tenants; + const uint32_t* query_tenants; + }; + + __device__ bool cuvs_filter_udf(uint32_t query_id, source_index_t source_id, void* filter_data) + { + auto* ctx = static_cast(filter_data); + return ctx->row_tenants[source_id] == ctx->query_tenants[query_id]; + } + )cpp"; +} + +void expect_same_results(cagra_search_result const& expected, cagra_search_result const& actual) +{ + ASSERT_EQ(expected.neighbors, actual.neighbors); + ASSERT_EQ(expected.distances.size(), actual.distances.size()); + for (size_t i = 0; i < expected.distances.size(); ++i) { + EXPECT_FLOAT_EQ(expected.distances[i], actual.distances[i]); + } +} + +class CagraUdfFilterTest : public ::testing::TestWithParam { + protected: + void SetUp() override + { + dataset.emplace(raft::make_device_matrix(res, n_rows, n_dim)); + queries.emplace(raft::make_device_matrix(res, n_queries, n_dim)); + + raft::random::RngState rng(1234ULL); + raft::random::uniform(res, rng, dataset->data_handle(), dataset->size(), -1.0f, 1.0f); + raft::random::uniform(res, rng, queries->data_handle(), queries->size(), -1.0f, 1.0f); + + cagra::index_params index_params; + index_params.metric = cuvs::distance::DistanceType::L2Expanded; + index_params.graph_degree = 32; + index_params.intermediate_graph_degree = 64; + index_params.graph_build_params = + cagra::graph_build_params::nn_descent_params(index_params.intermediate_graph_degree); + + index.emplace(cagra::build(res, index_params, raft::make_const_mdspan(dataset->view()))); + raft::resource::sync_stream(res); + } + + void TearDown() override + { + 
index.reset(); + queries.reset(); + dataset.reset(); + raft::resource::sync_stream(res); + } + + cagra_search_result search(cuvs::neighbors::filtering::base_filter const& filter, + float filtering_rate = -1.0f) + { + auto neighbors = raft::make_device_matrix(res, n_queries, k); + auto distances = raft::make_device_matrix(res, n_queries, k); + + cagra::search_params search_params; + search_params.algo = GetParam(); + search_params.itopk_size = 64; + search_params.max_queries = 2; + search_params.thread_block_size = 256; + search_params.filtering_rate = filtering_rate; + + cagra::search(res, + search_params, + *index, + raft::make_const_mdspan(queries->view()), + neighbors.view(), + distances.view(), + filter); + + auto stream = raft::resource::get_cuda_stream(res); + cagra_search_result result{std::vector(n_queries * k), + std::vector(n_queries * k)}; + raft::copy(result.neighbors.data(), neighbors.data_handle(), result.neighbors.size(), stream); + raft::copy(result.distances.data(), distances.data_handle(), result.distances.size(), stream); + raft::resource::sync_stream(res); + return result; + } + + raft::resources res; + std::optional> dataset = std::nullopt; + std::optional> queries = std::nullopt; + std::optional> index = std::nullopt; +}; + +TEST_P(CagraUdfFilterTest, AcceptAllMatchesNoFilter) +{ + cuvs::neighbors::filtering::none_sample_filter no_filter; + auto expected = search(no_filter, 0.0f); + + cuvs::neighbors::filtering::udf_filter udf_filter( + accept_all_udf_source(), nullptr, 0.0f, "cagra-accept-all-filter"); + auto actual = search(udf_filter); + + expect_same_results(expected, actual); +} + +TEST_P(CagraUdfFilterTest, RejectAllReturnsNoValidNeighbors) +{ + cuvs::neighbors::filtering::udf_filter udf_filter( + reject_all_udf_source(), nullptr, 0.999f, "cagra-reject-all-filter"); + auto result = search(udf_filter); + + // CAGRA algorithms do not all normalize empty-result sentinels the same way. 
Single-CTA + // clears the internal high-bit marker before writing output, so 0xffffffff can become + // 0x7fffffff; other paths may leave 0xffffffff. Both are invalid row ids for this index. + for (auto source_id : result.neighbors) { + EXPECT_GE(source_id, static_cast(n_rows)); + } +} + +TEST_P(CagraUdfFilterTest, HighFilteringRateReturnsOnlyValidNeighbors) +{ + cuvs::neighbors::filtering::udf_filter udf_filter(high_filtering_rate_udf_source(), + nullptr, + high_filtering_rate, + "cagra-high-filtering-rate-filter"); + auto result = search(udf_filter); + + for (auto source_id : result.neighbors) { + if (source_id < static_cast(n_rows)) { + EXPECT_GE(source_id, static_cast(high_filtering_threshold)); + } + } +} + +TEST_P(CagraUdfFilterTest, RepeatedUdfSearchWithSameCacheKeyMatches) +{ + cuvs::neighbors::filtering::udf_filter udf_filter( + accept_all_udf_source(), nullptr, 0.0f, "cagra-repeat-cache-filter"); + + auto first = search(udf_filter); + auto second = search(udf_filter); + + expect_same_results(first, second); +} + +TEST_P(CagraUdfFilterTest, InvalidSourceThrows) +{ + cuvs::neighbors::filtering::udf_filter udf_filter( + "this is not valid cuda source", nullptr, 0.0f, "cagra-invalid-filter"); + + EXPECT_THROW(search(udf_filter), std::exception); +} + +TEST_P(CagraUdfFilterTest, ThresholdMatchesEquivalentBitset) +{ + auto removed_indices = raft::make_device_vector(res, threshold); + thrust::sequence(raft::resource::get_thrust_policy(res), + thrust::device_pointer_cast(removed_indices.data_handle()), + thrust::device_pointer_cast(removed_indices.data_handle() + threshold)); + raft::resource::sync_stream(res); + + cuvs::core::bitset removed_indices_bitset( + res, removed_indices.view(), n_rows); + cuvs::neighbors::filtering::bitset_filter bitset_filter(removed_indices_bitset.view()); + + float const filtering_rate = static_cast(threshold) / static_cast(n_rows); + auto expected = search(bitset_filter, filtering_rate); + + cuvs::neighbors::filtering::udf_filter 
udf_filter( + threshold_udf_source(), nullptr, filtering_rate, "cagra-threshold-filter"); + auto actual = search(udf_filter, filtering_rate); + + expect_same_results(expected, actual); +} + +TEST_P(CagraUdfFilterTest, TenantContextHonorsQuerySpecificMetadata) +{ + std::vector host_row_tenants(n_rows); + std::vector host_query_tenants(n_queries); + for (int64_t i = 0; i < n_rows; ++i) { + host_row_tenants[static_cast(i)] = static_cast((i / 5) % 3); + } + for (int64_t q = 0; q < n_queries; ++q) { + host_query_tenants[static_cast(q)] = static_cast(q % 3); + } + + auto row_tenants = raft::make_device_vector(res, n_rows); + auto query_tenants = raft::make_device_vector(res, n_queries); + auto context = raft::make_device_vector(res, 1); + + auto stream = raft::resource::get_cuda_stream(res); + raft::copy(row_tenants.data_handle(), host_row_tenants.data(), host_row_tenants.size(), stream); + raft::copy( + query_tenants.data_handle(), host_query_tenants.data(), host_query_tenants.size(), stream); + + tenant_filter_context host_context{row_tenants.data_handle(), query_tenants.data_handle()}; + raft::copy(context.data_handle(), &host_context, 1, stream); + raft::resource::sync_stream(res); + + cuvs::neighbors::filtering::udf_filter udf_filter( + tenant_udf_source(), context.data_handle(), 2.0f / 3.0f, "cagra-tenant-filter"); + auto result = search(udf_filter); + + for (int64_t q = 0; q < n_queries; ++q) { + auto query_tenant = host_query_tenants[static_cast(q)]; + for (int64_t i = 0; i < k; ++i) { + auto source_id = result.neighbors[static_cast(q * k + i)]; + ASSERT_LT(source_id, static_cast(n_rows)); + EXPECT_EQ(host_row_tenants[source_id], query_tenant); + } + } +} + +INSTANTIATE_TEST_CASE_P(CagraUdfFilters, + CagraUdfFilterTest, + ::testing::Values(cagra::search_algo::SINGLE_CTA, + cagra::search_algo::MULTI_CTA, + cagra::search_algo::MULTI_KERNEL)); + +} // namespace +} // namespace cuvs::neighbors::cagra diff --git a/docs/source/filtering.rst b/docs/source/filtering.rst 
index cb168f94c8..bf6884042d 100644 --- a/docs/source/filtering.rst +++ b/docs/source/filtering.rst @@ -23,6 +23,25 @@ Bitmap A bitmap is based on the same principle as a bitset, but in two dimensions. This allows users to provide a different bitset for each query being searched. Check out RAFT's `bitmap API documentation `. +CAGRA filter UDF +================ + +CAGRA also supports a low-level JIT-LTO filter UDF for predicates that are more naturally expressed as CUDA device code. The UDF source must define a device function that returns `true` when a source vector is allowed and `false` when it should be rejected: + +.. code-block:: c++ + + __device__ bool cuvs_filter_udf(uint32_t query_id, + source_index_t source_id, + void* filter_data); + +`source_index_t` is currently `uint32_t` for CAGRA. `filter_data` is an opaque pointer passed through to the device predicate; if the UDF dereferences it, the pointer and any nested pointers must refer to device-accessible memory and remain valid for the duration of the search. + +When `cagra::search_params::filtering_rate` is negative, CAGRA uses `filtering::udf_filter::filtering_rate`. If both are negative, CAGRA assumes `0.0` because it cannot infer UDF selectivity from the source string. The filtering rate is the expected fraction of source vectors that the predicate rejects (so `0.0` means nothing is filtered out), not the fraction it accepts; providing a reasonable value helps CAGRA size its internal search work for selective filters. + +Filter UDFs are candidate-validity predicates only. They receive logical query and source identifiers plus the caller-provided context pointer; they are not given access to CAGRA graph traversal state, IVF probing decisions, PQ/VPQ encoded data, or other internal index layouts. cuVS still owns traversal, distance computation, and result selection. + +Filtered CAGRA search is still approximate ANN search. The UDF prevents rejected candidates from appearing in the returned results, but it does not guarantee exact brute-force filtered nearest-neighbor semantics.
For highly selective predicates, provide a realistic `filtering_rate` so CAGRA can size its internal search work appropriately. + Examples ======== @@ -68,6 +87,126 @@ Using a Bitset filter on a CAGRA index bitset_filter); +Using a CAGRA filter UDF +------------------------- + +.. code-block:: c++ + + struct tenant_filter_context { + const uint32_t* row_tenants; + const uint32_t* query_tenants; + }; + + std::string source = R"cpp( + struct tenant_filter_context { + const uint32_t* row_tenants; + const uint32_t* query_tenants; + }; + + __device__ bool cuvs_filter_udf(uint32_t query_id, + source_index_t source_id, + void* filter_data) + { + auto* ctx = static_cast(filter_data); + return ctx->row_tenants[source_id] == ctx->query_tenants[query_id]; + } + )cpp"; + + tenant_filter_context host_ctx{row_tenants_device, query_tenants_device}; + tenant_filter_context* ctx_device = copy_to_device(host_ctx); + + auto filter = cuvs::neighbors::filtering::udf_filter( + source, ctx_device, 0.75f, "tenant-filter-v1"); + + cagra::search(res, search_params, index, queries, neighbors, distances, filter); + + +Choosing among metadata UDF predicates +--------------------------------------- + +A single UDF source can contain several device predicates over the same context. Select the predicate to link by passing its name as the final `udf_filter` constructor argument. This is useful for vector database metadata filters such as tenant isolation, time ranges, language filters, and ACL checks. + +.. 
code-block:: c++ + + struct metadata_filter_context { + const uint32_t* row_tenant_ids; + const int64_t* row_timestamps; + const uint32_t* row_language_ids; + const uint64_t* row_acl_masks; + + const uint32_t* query_tenant_ids; + const int64_t* query_min_timestamps; + const uint64_t* query_allowed_language_masks; + const uint64_t* query_permission_masks; + }; + + std::string metadata_source = R"cpp( + struct metadata_filter_context { + const uint32_t* row_tenant_ids; + const int64_t* row_timestamps; + const uint32_t* row_language_ids; + const uint64_t* row_acl_masks; + + const uint32_t* query_tenant_ids; + const int64_t* query_min_timestamps; + const uint64_t* query_allowed_language_masks; + const uint64_t* query_permission_masks; + }; + + __device__ bool tenant_filter(uint32_t query_id, + source_index_t source_id, + void* filter_data) + { + auto* ctx = static_cast(filter_data); + return ctx->row_tenant_ids[source_id] == ctx->query_tenant_ids[query_id]; + } + + __device__ bool timestamp_filter(uint32_t query_id, + source_index_t source_id, + void* filter_data) + { + auto* ctx = static_cast(filter_data); + return ctx->row_timestamps[source_id] >= ctx->query_min_timestamps[query_id]; + } + + __device__ bool language_acl_filter(uint32_t query_id, + source_index_t source_id, + void* filter_data) + { + auto* ctx = static_cast(filter_data); + const auto language_bit = uint64_t{1} << ctx->row_language_ids[source_id]; + const bool language_ok = + (ctx->query_allowed_language_masks[query_id] & language_bit) != 0; + const bool acl_ok = + (ctx->row_acl_masks[source_id] & ctx->query_permission_masks[query_id]) != 0; + return language_ok && acl_ok; + } + )cpp"; + + metadata_filter_context host_ctx{row_tenant_ids_device, + row_timestamps_device, + row_language_ids_device, + row_acl_masks_device, + query_tenant_ids_device, + query_min_timestamps_device, + query_allowed_language_masks_device, + query_permission_masks_device}; + metadata_filter_context* ctx_device = 
copy_to_device(host_ctx); + + auto tenant_filter = cuvs::neighbors::filtering::udf_filter( + metadata_source, ctx_device, 0.80f, "tenant-filter-v1", "tenant_filter"); + + auto timestamp_filter = cuvs::neighbors::filtering::udf_filter( + metadata_source, ctx_device, 0.25f, "timestamp-filter-v1", "timestamp_filter"); + + auto language_acl_filter = cuvs::neighbors::filtering::udf_filter( + metadata_source, ctx_device, 0.95f, "language-acl-filter-v1", "language_acl_filter"); + + cagra::search(res, search_params, index, queries, neighbors, distances, tenant_filter); + cagra::search(res, search_params, index, queries, neighbors, distances, timestamp_filter); + cagra::search(res, search_params, index, queries, neighbors, distances, language_acl_filter); + + Using a Bitmap filter on a Brute-force index -------------------------------------------- diff --git a/docs/source/neighbors/cagra.rst b/docs/source/neighbors/cagra.rst index 471f3a915a..092a96007c 100644 --- a/docs/source/neighbors/cagra.rst +++ b/docs/source/neighbors/cagra.rst @@ -28,6 +28,10 @@ CAGRA supports filtered search and has improved multi-CTA algorithm in branch-25 To obtain an appropriate recall in filtered search, it is necessary to set search parameters according to the filtering rate, but since it is difficult for users to do this, CAGRA automatically adjusts `itopk_size` internally according to the filtering rate on a heuristic basis. If you want to disable this automatic adjustment, set `filtering_rate`, one of the search parameters, to `0.0`, and `itopk_size` will not be adjusted automatically. +CAGRA accepts bitset filters and JIT-LTO filter UDFs. A `filtering::udf_filter` is useful when the predicate needs device metadata such as per-row tenants, ACLs, or query-specific attributes. If `search_params::filtering_rate` is negative, CAGRA uses `udf_filter::filtering_rate`; if both are negative, it assumes `0.0` because the selectivity cannot be inferred from arbitrary CUDA source. 
+ +A filter UDF only decides whether a logical source id is valid for a logical query id. It cannot control CAGRA graph traversal, access PQ/VPQ or graph internals, or change distance computation. Because CAGRA remains an approximate ANN algorithm, filtered results are not guaranteed to match exact brute-force filtered search, especially for very selective predicates without an accurate `filtering_rate`. + Configuration parameters ------------------------ diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt index 034b0b3d96..d63ddbdb71 100644 --- a/examples/cpp/CMakeLists.txt +++ b/examples/cpp/CMakeLists.txt @@ -31,6 +31,7 @@ include(../cmake/thirdparty/get_cuvs.cmake) # -------------- compile tasks ----------------- # add_executable(BRUTE_FORCE_EXAMPLE src/brute_force_bitmap.cu) add_executable(CAGRA_EXAMPLE src/cagra_example.cu) +add_executable(CAGRA_FILTER_UDF_EXAMPLE src/cagra_filter_udf_example.cu) add_executable(CAGRA_HNSW_ACE_EXAMPLE src/cagra_hnsw_ace_example.cu) add_executable(CAGRA_PERSISTENT_EXAMPLE src/cagra_persistent_example.cu) add_executable(DYNAMIC_BATCHING_EXAMPLE src/dynamic_batching_example.cu) @@ -44,6 +45,9 @@ add_executable(SCANN_EXAMPLE src/scann_example.cu) # installed in a conda environment, if one exists target_link_libraries(BRUTE_FORCE_EXAMPLE PRIVATE cuvs::cuvs $) target_link_libraries(CAGRA_EXAMPLE PRIVATE cuvs::cuvs $) +target_link_libraries( + CAGRA_FILTER_UDF_EXAMPLE PRIVATE cuvs::cuvs $ +) target_link_libraries(CAGRA_HNSW_ACE_EXAMPLE PRIVATE cuvs::cuvs $) target_link_libraries( CAGRA_PERSISTENT_EXAMPLE PRIVATE cuvs::cuvs $ Threads::Threads diff --git a/examples/cpp/src/cagra_filter_udf_example.cu b/examples/cpp/src/cagra_filter_udf_example.cu new file mode 100644 index 0000000000..2726db9bf3 --- /dev/null +++ b/examples/cpp/src/cagra_filter_udf_example.cu @@ -0,0 +1,251 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace { + +constexpr int64_t n_rows = 4096; +constexpr int64_t n_dim = 32; +constexpr int64_t n_queries = 4; +constexpr int64_t k = 8; + +struct metadata_filter_context { + const uint32_t* row_tenant_ids; + const int64_t* row_timestamps; + const uint32_t* row_language_ids; + const uint64_t* row_acl_masks; + + const uint32_t* query_tenant_ids; + const int64_t* query_min_timestamps; + const uint64_t* query_allowed_language_masks; + const uint64_t* query_permission_masks; +}; + +std::string metadata_udf_source() +{ + return R"cpp( + struct metadata_filter_context { + const uint32_t* row_tenant_ids; + const int64_t* row_timestamps; + const uint32_t* row_language_ids; + const uint64_t* row_acl_masks; + + const uint32_t* query_tenant_ids; + const int64_t* query_min_timestamps; + const uint64_t* query_allowed_language_masks; + const uint64_t* query_permission_masks; + }; + + __device__ bool tenant_filter(uint32_t query_id, source_index_t source_id, void* filter_data) + { + auto* ctx = static_cast(filter_data); + return ctx->row_tenant_ids[source_id] == ctx->query_tenant_ids[query_id]; + } + + __device__ bool timestamp_filter(uint32_t query_id, source_index_t source_id, void* filter_data) + { + auto* ctx = static_cast(filter_data); + return ctx->row_timestamps[source_id] >= ctx->query_min_timestamps[query_id]; + } + + __device__ bool language_acl_filter(uint32_t query_id, source_index_t source_id, void* filter_data) + { + auto* ctx = static_cast(filter_data); + const auto language_bit = uint64_t{1} << ctx->row_language_ids[source_id]; + const bool language_ok = (ctx->query_allowed_language_masks[query_id] & language_bit) != 0; + const bool acl_ok = (ctx->row_acl_masks[source_id] & ctx->query_permission_masks[query_id]) != 0; + return language_ok && acl_ok; + } + )cpp"; +} + +template +void 
copy_to_device(raft::device_resources const& res, DeviceVectorT& dst, HostVectorT const& src) +{ + raft::copy(dst.data_handle(), src.data(), src.size(), raft::resource::get_cuda_stream(res)); +} + +std::vector copy_neighbors_to_host( + raft::device_resources const& res, + raft::device_matrix_view neighbors) +{ + std::vector host(neighbors.size()); + raft::copy( + host.data(), neighbors.data_handle(), host.size(), raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + return host; +} + +template +void validate_and_print(char const* name, + raft::device_resources const& res, + raft::device_matrix_view neighbors, + PredicateT is_valid) +{ + auto host_neighbors = copy_neighbors_to_host(res, neighbors); + + std::cout << name << " first query neighbors:"; + for (int64_t i = 0; i < k; ++i) { + std::cout << " " << host_neighbors[static_cast(i)]; + } + std::cout << std::endl; + + for (int64_t q = 0; q < n_queries; ++q) { + for (int64_t i = 0; i < k; ++i) { + auto source_id = host_neighbors[static_cast(q * k + i)]; + if (source_id < static_cast(n_rows) && !is_valid(q, source_id)) { + std::cerr << name << " produced invalid source_id=" << source_id << " for query=" << q + << std::endl; + std::exit(1); + } + } + } +} + +} // namespace + +int main() +{ + raft::device_resources res; + + rmm::mr::pool_memory_resource pool_mr(rmm::mr::get_current_device_resource_ref(), + 1024 * 1024 * 1024ull); + rmm::mr::set_current_device_resource(pool_mr); + + auto dataset = raft::make_device_matrix(res, n_rows, n_dim); + auto queries = raft::make_device_matrix(res, n_queries, n_dim); + + raft::random::RngState rng(1234ULL); + raft::random::uniform(res, rng, dataset.data_handle(), dataset.size(), -1.0f, 1.0f); + raft::random::uniform(res, rng, queries.data_handle(), queries.size(), -1.0f, 1.0f); + + cuvs::neighbors::cagra::index_params index_params; + index_params.metric = cuvs::distance::DistanceType::L2Expanded; + index_params.graph_degree = 32; + 
index_params.intermediate_graph_degree = 64; + index_params.graph_build_params = cuvs::neighbors::cagra::graph_build_params::nn_descent_params( + index_params.intermediate_graph_degree); + + std::cout << "Building CAGRA index" << std::endl; + auto index = + cuvs::neighbors::cagra::build(res, index_params, raft::make_const_mdspan(dataset.view())); + + std::vector row_tenant_ids(n_rows); + std::vector row_timestamps(n_rows); + std::vector row_language_ids(n_rows); + std::vector row_acl_masks(n_rows); + for (int64_t i = 0; i < n_rows; ++i) { + row_tenant_ids[static_cast(i)] = static_cast(i % 4); + row_timestamps[static_cast(i)] = 1'700'000'000 + i; + row_language_ids[static_cast(i)] = static_cast(i % 8); + row_acl_masks[static_cast(i)] = uint64_t{1} << (i % 16); + } + + std::vector query_tenant_ids{0, 1, 2, 3}; + std::vector query_min_timestamps{ + 1'700'003'000, 1'700'002'000, 1'700'001'000, 1'700'000'500}; + std::vector query_allowed_language_masks{(uint64_t{1} << 0) | (uint64_t{1} << 1), + (uint64_t{1} << 2) | (uint64_t{1} << 3), + (uint64_t{1} << 4) | (uint64_t{1} << 5), + (uint64_t{1} << 6) | (uint64_t{1} << 7)}; + std::vector query_permission_masks{(uint64_t{1} << 0) | (uint64_t{1} << 8), + (uint64_t{1} << 2) | (uint64_t{1} << 10), + (uint64_t{1} << 4) | (uint64_t{1} << 12), + (uint64_t{1} << 6) | (uint64_t{1} << 14)}; + + auto row_tenant_ids_device = raft::make_device_vector(res, n_rows); + auto row_timestamps_device = raft::make_device_vector(res, n_rows); + auto row_language_ids_device = raft::make_device_vector(res, n_rows); + auto row_acl_masks_device = raft::make_device_vector(res, n_rows); + auto query_tenant_ids_device = raft::make_device_vector(res, n_queries); + auto query_min_timestamps_device = raft::make_device_vector(res, n_queries); + auto query_allowed_language_masks_device = + raft::make_device_vector(res, n_queries); + auto query_permission_masks_device = raft::make_device_vector(res, n_queries); + auto context_device = 
raft::make_device_vector(res, 1); + + copy_to_device(res, row_tenant_ids_device, row_tenant_ids); + copy_to_device(res, row_timestamps_device, row_timestamps); + copy_to_device(res, row_language_ids_device, row_language_ids); + copy_to_device(res, row_acl_masks_device, row_acl_masks); + copy_to_device(res, query_tenant_ids_device, query_tenant_ids); + copy_to_device(res, query_min_timestamps_device, query_min_timestamps); + copy_to_device(res, query_allowed_language_masks_device, query_allowed_language_masks); + copy_to_device(res, query_permission_masks_device, query_permission_masks); + + metadata_filter_context host_context{row_tenant_ids_device.data_handle(), + row_timestamps_device.data_handle(), + row_language_ids_device.data_handle(), + row_acl_masks_device.data_handle(), + query_tenant_ids_device.data_handle(), + query_min_timestamps_device.data_handle(), + query_allowed_language_masks_device.data_handle(), + query_permission_masks_device.data_handle()}; + raft::copy(context_device.data_handle(), &host_context, 1, raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + + auto neighbors = raft::make_device_matrix(res, n_queries, k); + auto distances = raft::make_device_matrix(res, n_queries, k); + + cuvs::neighbors::cagra::search_params search_params; + search_params.algo = cuvs::neighbors::cagra::search_algo::MULTI_CTA; + search_params.itopk_size = 128; + search_params.max_queries = n_queries; + search_params.thread_block_size = 256; + + auto source = metadata_udf_source(); + + auto run_filter = + [&](char const* label, char const* function_name, float filtering_rate, auto is_valid) { + auto filter = cuvs::neighbors::filtering::udf_filter( + source, context_device.data_handle(), filtering_rate, label, function_name); + cuvs::neighbors::cagra::search(res, + search_params, + index, + raft::make_const_mdspan(queries.view()), + neighbors.view(), + distances.view(), + filter); + validate_and_print(label, res, neighbors.view(), is_valid); + 
}; + + run_filter("tenant_filter", "tenant_filter", 0.75f, [&](int64_t query_id, uint32_t source_id) { + return row_tenant_ids[source_id] == query_tenant_ids[query_id]; + }); + + run_filter( + "timestamp_filter", "timestamp_filter", 0.50f, [&](int64_t query_id, uint32_t source_id) { + return row_timestamps[source_id] >= query_min_timestamps[query_id]; + }); + + run_filter( + "language_acl_filter", + "language_acl_filter", + 0.875f, + [&](int64_t query_id, uint32_t source_id) { + const auto language_bit = uint64_t{1} << row_language_ids[source_id]; + const bool language_ok = (query_allowed_language_masks[query_id] & language_bit) != 0; + const bool acl_ok = (row_acl_masks[source_id] & query_permission_masks[query_id]) != 0; + return language_ok && acl_ok; + }); + + std::cout << "All CAGRA filter UDF examples produced valid filtered neighbors." << std::endl; + return 0; +} From 292f7e857bd795d6341084714dffba87f72602bc Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Thu, 4 Jun 2026 15:26:42 -0500 Subject: [PATCH 02/11] ENH Hide CAGRA filter UDF cache keys from the public API --- cpp/include/cuvs/neighbors/common.hpp | 12 ++++------- .../jit_lto_kernels/sample_filter_udf.cuh | 2 -- .../search_single_cta_kernel_launcher_jit.cuh | 1 - .../neighbors/ann_cagra/test_filter_udf.cu | 20 ++++++++----------- examples/cpp/src/cagra_filter_udf_example.cu | 2 +- 5 files changed, 13 insertions(+), 24 deletions(-) diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index ad190cc7b1..8d3b04810f 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -638,22 +638,18 @@ struct udf_filter : public base_filter { void* filter_data = nullptr; /** Estimated fraction of rows rejected by the predicate, or negative if unknown. */ float filtering_rate = -1.0f; - /** Optional stable cache key for equivalent generated source. 
*/ - std::string cache_key; /** Device function name to call from the generated CAGRA sample filter. */ std::string function_name = "cuvs_filter_udf"; udf_filter() = default; - udf_filter(std::string source, - void* filter_data = nullptr, - float filtering_rate = -1.0f, - std::string cache_key = {}, - std::string function_name = "cuvs_filter_udf") + explicit udf_filter(std::string source, + void* filter_data = nullptr, + float filtering_rate = -1.0f, + std::string function_name = "cuvs_filter_udf") : source(std::move(source)), filter_data(filter_data), filtering_rate(filtering_rate), - cache_key(std::move(cache_key)), function_name(std::move(function_name)) { } diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh index a09e1cda35..d69d967287 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh @@ -80,8 +80,6 @@ std::unique_ptr make_cagra_sample_filter_udf_fragment( std::string key = "cagra_sample_filter_udf:"; key += cagra_udf_source_index_type_name(); key += ":"; - key += udf->cache_key.empty() ? 
udf->source : udf->cache_key; - key += ":"; key += udf->function_name; key += ":"; key += code; diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh index 6a3118e3ae..ac24e28a30 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh @@ -69,7 +69,6 @@ std::uint64_t cagra_udf_source_hash(const SampleFilterT& sample_filter) if (const auto* udf = get_cagra_udf_filter(sample_filter); udf != nullptr) { std::uint64_t seed = 0; seed = cagra_hash_combine(seed, std::hash{}(udf->source)); - seed = cagra_hash_combine(seed, std::hash{}(udf->cache_key)); seed = cagra_hash_combine(seed, std::hash{}(udf->function_name)); return seed; } diff --git a/cpp/tests/neighbors/ann_cagra/test_filter_udf.cu b/cpp/tests/neighbors/ann_cagra/test_filter_udf.cu index 10da68dbdd..c19e30ff9c 100644 --- a/cpp/tests/neighbors/ann_cagra/test_filter_udf.cu +++ b/cpp/tests/neighbors/ann_cagra/test_filter_udf.cu @@ -174,8 +174,7 @@ TEST_P(CagraUdfFilterTest, AcceptAllMatchesNoFilter) cuvs::neighbors::filtering::none_sample_filter no_filter; auto expected = search(no_filter, 0.0f); - cuvs::neighbors::filtering::udf_filter udf_filter( - accept_all_udf_source(), nullptr, 0.0f, "cagra-accept-all-filter"); + cuvs::neighbors::filtering::udf_filter udf_filter(accept_all_udf_source(), nullptr, 0.0f); auto actual = search(udf_filter); expect_same_results(expected, actual); @@ -183,8 +182,7 @@ TEST_P(CagraUdfFilterTest, AcceptAllMatchesNoFilter) TEST_P(CagraUdfFilterTest, RejectAllReturnsNoValidNeighbors) { - cuvs::neighbors::filtering::udf_filter udf_filter( - reject_all_udf_source(), nullptr, 0.999f, "cagra-reject-all-filter"); + cuvs::neighbors::filtering::udf_filter udf_filter(reject_all_udf_source(), nullptr, 0.999f); auto result = search(udf_filter); // CAGRA algorithms do not all normalize 
empty-result sentinels the same way. Single-CTA @@ -199,8 +197,7 @@ TEST_P(CagraUdfFilterTest, HighFilteringRateReturnsOnlyValidNeighbors) { cuvs::neighbors::filtering::udf_filter udf_filter(high_filtering_rate_udf_source(), nullptr, - high_filtering_rate, - "cagra-high-filtering-rate-filter"); + high_filtering_rate); auto result = search(udf_filter); for (auto source_id : result.neighbors) { @@ -210,10 +207,9 @@ TEST_P(CagraUdfFilterTest, HighFilteringRateReturnsOnlyValidNeighbors) } } -TEST_P(CagraUdfFilterTest, RepeatedUdfSearchWithSameCacheKeyMatches) +TEST_P(CagraUdfFilterTest, RepeatedUdfSearchWithSameSourceMatches) { - cuvs::neighbors::filtering::udf_filter udf_filter( - accept_all_udf_source(), nullptr, 0.0f, "cagra-repeat-cache-filter"); + cuvs::neighbors::filtering::udf_filter udf_filter(accept_all_udf_source(), nullptr, 0.0f); auto first = search(udf_filter); auto second = search(udf_filter); @@ -224,7 +220,7 @@ TEST_P(CagraUdfFilterTest, RepeatedUdfSearchWithSameCacheKeyMatches) TEST_P(CagraUdfFilterTest, InvalidSourceThrows) { cuvs::neighbors::filtering::udf_filter udf_filter( - "this is not valid cuda source", nullptr, 0.0f, "cagra-invalid-filter"); + "this is not valid cuda source", nullptr, 0.0f); EXPECT_THROW(search(udf_filter), std::exception); } @@ -245,7 +241,7 @@ TEST_P(CagraUdfFilterTest, ThresholdMatchesEquivalentBitset) auto expected = search(bitset_filter, filtering_rate); cuvs::neighbors::filtering::udf_filter udf_filter( - threshold_udf_source(), nullptr, filtering_rate, "cagra-threshold-filter"); + threshold_udf_source(), nullptr, filtering_rate); auto actual = search(udf_filter, filtering_rate); expect_same_results(expected, actual); @@ -276,7 +272,7 @@ TEST_P(CagraUdfFilterTest, TenantContextHonorsQuerySpecificMetadata) raft::resource::sync_stream(res); cuvs::neighbors::filtering::udf_filter udf_filter( - tenant_udf_source(), context.data_handle(), 2.0f / 3.0f, "cagra-tenant-filter"); + tenant_udf_source(), context.data_handle(), 2.0f 
/ 3.0f); auto result = search(udf_filter); for (int64_t q = 0; q < n_queries; ++q) { diff --git a/examples/cpp/src/cagra_filter_udf_example.cu b/examples/cpp/src/cagra_filter_udf_example.cu index 2726db9bf3..0ab42dd580 100644 --- a/examples/cpp/src/cagra_filter_udf_example.cu +++ b/examples/cpp/src/cagra_filter_udf_example.cu @@ -215,7 +215,7 @@ int main() auto run_filter = [&](char const* label, char const* function_name, float filtering_rate, auto is_valid) { auto filter = cuvs::neighbors::filtering::udf_filter( - source, context_device.data_handle(), filtering_rate, label, function_name); + source, context_device.data_handle(), filtering_rate, function_name); cuvs::neighbors::cagra::search(res, search_params, index, From 296d44c68654ba43e9e0ef64834979f13f144039 Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Thu, 4 Jun 2026 15:35:12 -0500 Subject: [PATCH 03/11] ENH Generalize the CAGRA JIT filter payload header --- ...ra_bitset.cuh => cagra_filter_payload.cuh} | 25 +++++++++++-------- .../cagra/jit_lto_kernels/kernel_def.hpp | 2 +- .../jit_lto_kernels/sample_filter_udf.cuh | 2 +- .../jit_lto_kernels/search_multi_cta_jit.cuh | 2 +- .../jit_lto_kernels/search_multi_jit.cuh | 2 +- .../jit_lto_kernels/search_single_cta_jit.cuh | 2 +- .../search_multi_kernel_launcher_jit.cuh | 2 +- .../search_single_cta_kernel_launcher_jit.cuh | 11 +++++--- .../detail/cagra/shared_launcher_jit.hpp | 5 ++-- 9 files changed, 31 insertions(+), 22 deletions(-) rename cpp/src/neighbors/detail/cagra/jit_lto_kernels/{cagra_bitset.cuh => cagra_filter_payload.cuh} (74%) diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_bitset.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh similarity index 74% rename from cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_bitset.cuh rename to cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh index 7c5e681bca..07f9d26d83 100644 --- 
a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_bitset.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh @@ -13,14 +13,14 @@ namespace cuvs::neighbors::cagra::detail { template -using cagra_bitset = cuvs::neighbors::detail::bitset_filter_data_t; +using cagra_filter_data_storage = ::cuvs::neighbors::detail::bitset_filter_data_t; enum class cagra_filter_kind : std::uint32_t { none = 0, bitset = 1, udf = 2 }; /// Host/device payload for linked CAGRA sample filters plus query offset for wrapped filters. template struct cagra_sample_filter { - cagra_bitset bitset{}; + cagra_filter_data_storage filter_data_storage{}; void* filter_data{nullptr}; cagra_filter_kind filter_kind{cagra_filter_kind::none}; std::uint32_t query_id_offset{0}; @@ -30,25 +30,26 @@ template struct is_bitset_filter : std::false_type {}; template -struct is_bitset_filter> +struct is_bitset_filter<::cuvs::neighbors::filtering::bitset_filter> : std::true_type {}; template struct is_udf_filter : std::false_type {}; template <> -struct is_udf_filter : std::true_type {}; +struct is_udf_filter<::cuvs::neighbors::filtering::udf_filter> : std::true_type {}; template void fill_cagra_sample_filter(cagra_sample_filter& out, const FilterT& filter) { using DecayedFilter = std::decay_t; if constexpr (is_bitset_filter::value) { - const auto bitset_view = filter.view(); - out.bitset.bitset_ptr = const_cast(bitset_view.data()); - out.bitset.bitset_len = static_cast(bitset_view.size()); - out.bitset.original_nbits = static_cast(bitset_view.get_original_nbits()); - out.filter_kind = cagra_filter_kind::bitset; + const auto bitset_view = filter.view(); + out.filter_data_storage.bitset_ptr = const_cast(bitset_view.data()); + out.filter_data_storage.bitset_len = static_cast(bitset_view.size()); + out.filter_data_storage.original_nbits = static_cast( + bitset_view.get_original_nbits()); + out.filter_kind = cagra_filter_kind::bitset; } else if constexpr (is_udf_filter::value) { 
out.filter_data = filter.filter_data; out.filter_kind = cagra_filter_kind::udf; @@ -73,7 +74,7 @@ cagra_sample_filter extract_cagra_sample_filter(const SampleFilter } template -const cuvs::neighbors::filtering::udf_filter* get_cagra_udf_filter( +const ::cuvs::neighbors::filtering::udf_filter* get_cagra_udf_filter( const SampleFilterT& sample_filter) { using DecayedFilter = std::decay_t; @@ -91,7 +92,9 @@ __device__ __forceinline__ void* get_cagra_sample_filter_data( cagra_sample_filter& payload) { if (payload.filter_kind == cagra_filter_kind::udf) { return payload.filter_data; } - if (payload.filter_kind == cagra_filter_kind::bitset) { return &payload.bitset; } + if (payload.filter_kind == cagra_filter_kind::bitset) { + return &payload.filter_data_storage; + } return nullptr; } diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp index 58c32707fe..c9702d3b72 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp @@ -12,7 +12,7 @@ #include #include "../compute_distance.hpp" // dataset_descriptor_base_t -#include "cagra_bitset.cuh" +#include "cagra_filter_payload.cuh" #include "search_single_cta_device_helpers.cuh" namespace cuvs::neighbors::cagra::detail { diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh index d69d967287..b1381c30f1 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh @@ -5,7 +5,7 @@ #pragma once -#include "cagra_bitset.cuh" +#include "cagra_filter_payload.cuh" #include #include diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh index 446b15cc3e..e6e1139a59 100644 --- 
a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh @@ -20,7 +20,7 @@ #include #endif -#include "cagra_bitset.cuh" +#include "cagra_filter_payload.cuh" #include "device_common_jit.cuh" #include "extern_device_functions.cuh" diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh index b580db0b80..157d31de29 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh @@ -8,7 +8,7 @@ #include "../../neighbors_device_intrinsics.cuh" #include "../hashmap.hpp" #include "../utils.hpp" -#include "cagra_bitset.cuh" +#include "cagra_filter_payload.cuh" #include #include diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh index 69098c75e5..0c9be9055a 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh @@ -29,7 +29,7 @@ #endif // Include extern function declarations before namespace so they're available to kernel definitions -#include "cagra_bitset.cuh" +#include "cagra_filter_payload.cuh" #include "extern_device_functions.cuh" // Include shared JIT device functions #include "device_common_jit.cuh" diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh index 4d7b087b80..f2fa0a327b 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh @@ -13,7 +13,7 @@ #include "jit_lto_kernels/kernel_def.hpp" #include "jit_lto_kernels/search_multi_kernel_planner.hpp" #include "search_plan.cuh" // 
For search_params -#include "shared_launcher_jit.hpp" // cagra_bitset / cagra_sample_filter, sample_filter_jit_tag_t, tags +#include "shared_launcher_jit.hpp" // sample-filter payload helpers and JIT tags #include #include #include diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh index ac24e28a30..d6c25a3b52 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh @@ -81,9 +81,14 @@ std::uint64_t cagra_sample_filter_hash(const SampleFilterT& sample_filter) const auto payload = extract_cagra_sample_filter(sample_filter); std::uint64_t seed = static_cast(payload.filter_kind); seed = cagra_hash_combine( - seed, static_cast(reinterpret_cast(payload.bitset.bitset_ptr))); - seed = cagra_hash_combine(seed, static_cast(payload.bitset.bitset_len)); - seed = cagra_hash_combine(seed, static_cast(payload.bitset.original_nbits)); + seed, + static_cast( + reinterpret_cast(payload.filter_data_storage.bitset_ptr))); + seed = cagra_hash_combine(seed, + static_cast(payload.filter_data_storage.bitset_len)); + seed = cagra_hash_combine(seed, + static_cast( + payload.filter_data_storage.original_nbits)); seed = cagra_hash_combine( seed, static_cast(reinterpret_cast(payload.filter_data))); seed = cagra_hash_combine(seed, static_cast(payload.query_id_offset)); diff --git a/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp index d6de21644f..ecb744ae30 100644 --- a/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp +++ b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp @@ -9,7 +9,7 @@ #include #include "../../sample_filter.cuh" // For none_sample_filter, bitset_filter -#include "jit_lto_kernels/cagra_bitset.cuh" // is_bitset_filter, cagra_bitset, cagra_sample_filter, extract +#include 
"jit_lto_kernels/cagra_filter_payload.cuh" // sample-filter payload helpers #include #include @@ -115,7 +115,8 @@ struct sample_filter_jit_tag { cagra_jit_sample_filter_tag_type_always_false, "CAGRA JIT: sample_filter_jit_tag does not know how to link this filter. " "CagraSampleFilterWithQueryIdOffset requires Inner to be a supported " - "built-in filter or udf_filter (see cagra_bitset.cuh and sample_filter_utils.cuh). " + "built-in filter or udf_filter (see cagra_filter_payload.cuh and " + "sample_filter_utils.cuh). " "For a new filter kind, add a sample_filter_jit_tag branch. " "(SAMPLE_FILTER_T in error; check InnerFilter in compiler output.)"); } From 67a03ddc2b10ce1ee625f803570799a8cd820409 Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Thu, 4 Jun 2026 15:39:11 -0500 Subject: [PATCH 04/11] ENH Assign CAGRA bitset filter metadata as opaque payload storage --- .../cagra/jit_lto_kernels/cagra_filter_payload.cuh | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh index 07f9d26d83..117a13804d 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh @@ -44,11 +44,12 @@ void fill_cagra_sample_filter(cagra_sample_filter& out, const Filt { using DecayedFilter = std::decay_t; if constexpr (is_bitset_filter::value) { - const auto bitset_view = filter.view(); - out.filter_data_storage.bitset_ptr = const_cast(bitset_view.data()); - out.filter_data_storage.bitset_len = static_cast(bitset_view.size()); - out.filter_data_storage.original_nbits = static_cast( - bitset_view.get_original_nbits()); + const auto bitset_view = filter.view(); + out.filter_data_storage = + cagra_filter_data_storage{const_cast(bitset_view.data()), + static_cast(bitset_view.size()), + static_cast( + 
bitset_view.get_original_nbits())}; out.filter_kind = cagra_filter_kind::bitset; } else if constexpr (is_udf_filter::value) { out.filter_data = filter.filter_data; @@ -93,6 +94,7 @@ __device__ __forceinline__ void* get_cagra_sample_filter_data( { if (payload.filter_kind == cagra_filter_kind::udf) { return payload.filter_data; } if (payload.filter_kind == cagra_filter_kind::bitset) { + // The payload is passed by value to kernels; take the embedded storage address on device. return &payload.filter_data_storage; } return nullptr; From 1365f929d0b9d15a0375c4b5aa59db29da6aac44 Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Thu, 4 Jun 2026 15:41:47 -0500 Subject: [PATCH 05/11] ENH Remove runtime filter kind from CAGRA JIT payload --- .../jit_lto_kernels/cagra_filter_payload.cuh | 9 ++------- .../search_single_cta_kernel_launcher_jit.cuh | 17 ++++++++++++++++- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh index 117a13804d..abc459482f 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh @@ -15,14 +15,11 @@ namespace cuvs::neighbors::cagra::detail { template using cagra_filter_data_storage = ::cuvs::neighbors::detail::bitset_filter_data_t; -enum class cagra_filter_kind : std::uint32_t { none = 0, bitset = 1, udf = 2 }; - /// Host/device payload for linked CAGRA sample filters plus query offset for wrapped filters. 
template struct cagra_sample_filter { cagra_filter_data_storage filter_data_storage{}; void* filter_data{nullptr}; - cagra_filter_kind filter_kind{cagra_filter_kind::none}; std::uint32_t query_id_offset{0}; }; @@ -50,10 +47,8 @@ void fill_cagra_sample_filter(cagra_sample_filter& out, const Filt static_cast(bitset_view.size()), static_cast( bitset_view.get_original_nbits())}; - out.filter_kind = cagra_filter_kind::bitset; } else if constexpr (is_udf_filter::value) { out.filter_data = filter.filter_data; - out.filter_kind = cagra_filter_kind::udf; } } @@ -92,8 +87,8 @@ template __device__ __forceinline__ void* get_cagra_sample_filter_data( cagra_sample_filter& payload) { - if (payload.filter_kind == cagra_filter_kind::udf) { return payload.filter_data; } - if (payload.filter_kind == cagra_filter_kind::bitset) { + if (payload.filter_data != nullptr) { return payload.filter_data; } + if (payload.filter_data_storage.bitset_ptr != nullptr) { // The payload is passed by value to kernels; take the embedded storage address on device. 
return &payload.filter_data_storage; } diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh index d6c25a3b52..595a705726 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh @@ -75,11 +75,26 @@ std::uint64_t cagra_udf_source_hash(const SampleFilterT& sample_filter) return 0; } +template +std::uint64_t cagra_sample_filter_type_id(const SampleFilterT& sample_filter) +{ + using DecayedFilter = std::decay_t; + if constexpr (is_udf_filter::value) { + return 2; + } else if constexpr (is_bitset_filter::value) { + return 1; + } else if constexpr (requires { sample_filter.filter; }) { + return cagra_sample_filter_type_id(sample_filter.filter); + } else { + return 0; + } +} + template std::uint64_t cagra_sample_filter_hash(const SampleFilterT& sample_filter) { const auto payload = extract_cagra_sample_filter(sample_filter); - std::uint64_t seed = static_cast(payload.filter_kind); + std::uint64_t seed = cagra_sample_filter_type_id(sample_filter); seed = cagra_hash_combine( seed, static_cast( From 95aabc36b40bfead447110c71dc77933c1754d66 Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Thu, 4 Jun 2026 15:45:05 -0500 Subject: [PATCH 06/11] ENH Move CAGRA sample filter data selection into the payload --- .../jit_lto_kernels/cagra_filter_payload.cuh | 22 +++++++++---------- .../jit_lto_kernels/search_multi_cta_jit.cuh | 4 ++-- .../jit_lto_kernels/search_multi_jit.cuh | 4 ++-- .../jit_lto_kernels/search_single_cta_jit.cuh | 4 ++-- 4 files changed, 16 insertions(+), 18 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh index abc459482f..de96a7c7cc 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh +++ 
b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh @@ -21,6 +21,16 @@ struct cagra_sample_filter { cagra_filter_data_storage filter_data_storage{}; void* filter_data{nullptr}; std::uint32_t query_id_offset{0}; + + __device__ __forceinline__ void* sample_filter_data() + { + if (filter_data != nullptr) { return filter_data; } + if (filter_data_storage.bitset_ptr != nullptr) { + // The payload is passed by value to kernels; take the embedded storage address on device. + return &filter_data_storage; + } + return nullptr; + } }; template @@ -83,16 +93,4 @@ const ::cuvs::neighbors::filtering::udf_filter* get_cagra_udf_filter( } } -template -__device__ __forceinline__ void* get_cagra_sample_filter_data( - cagra_sample_filter& payload) -{ - if (payload.filter_data != nullptr) { return payload.filter_data; } - if (payload.filter_data_storage.bitset_ptr != nullptr) { - // The payload is passed by value to kernels; take the embedded storage address on device. - return &payload.filter_data_storage; - } - return nullptr; -} - } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh index e6e1139a59..4c4f2e4f62 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh @@ -261,7 +261,7 @@ __device__ void search_kernel_jit( const auto parent_id = result_indices_buffer[parent_indices_buffer[p]] & ~index_msb_1_mask; if (!sample_filter(query_id + query_id_offset, to_source_index(parent_id), - get_cagra_sample_filter_data(filter_payload))) { + filter_payload.sample_filter_data())) { // If the parent must not be in the resulting top-k list, remove from the parent list result_distances_buffer[parent_indices_buffer[p]] = utils::get_max_value(); result_indices_buffer[parent_indices_buffer[p]] = invalid_index; @@ 
-280,7 +280,7 @@ __device__ void search_kernel_jit( index &= ~index_msb_1_mask; if (!sample_filter(query_id + query_id_offset, to_source_index(index), - get_cagra_sample_filter_data(filter_payload))) { + filter_payload.sample_filter_data())) { result_indices_buffer[i] = invalid_index; result_distances_buffer[i] = utils::get_max_value(); } diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh index 157d31de29..f2495c0ad6 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh @@ -170,7 +170,7 @@ __device__ void compute_distance_to_child_nodes_kernel_jit( ? static_cast(parent_index) : static_cast(source_indices_ptr[parent_index]); if (!sample_filter( - filter_query_id, node_id, get_cagra_sample_filter_data(filter_payload))) { + filter_query_id, node_id, filter_payload.sample_filter_data())) { parent_candidates_ptr[parent_list_index + (lds * local_query_id)] = utils::get_max_value(); parent_distance_ptr[parent_list_index + (lds * local_query_id)] = @@ -206,7 +206,7 @@ __device__ void apply_filter_kernel_jit( : source_indices_ptr[result_indices_ptr[index]]; if (!sample_filter( - query_id_offset + j, node_id, get_cagra_sample_filter_data(filter_payload))) { + query_id_offset + j, node_id, filter_payload.sample_filter_data())) { result_indices_ptr[index] = utils::get_max_value(); result_distances_ptr[index] = utils::get_max_value(); } diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh index 0c9be9055a..d481c44946 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh @@ -306,7 +306,7 @@ RAFT_DEVICE_INLINE_FUNCTION void search_core( const auto parent_id = 
result_indices_buffer[parent_list_buffer[p]] & ~index_msb_1_mask; if (!sample_filter(query_id + query_id_offset, to_source_index(parent_id), - get_cagra_sample_filter_data(filter_payload))) { + filter_payload.sample_filter_data())) { result_distances_buffer[parent_list_buffer[p]] = utils::get_max_value(); result_indices_buffer[parent_list_buffer[p]] = invalid_index; *filter_flag = 1; @@ -327,7 +327,7 @@ RAFT_DEVICE_INLINE_FUNCTION void search_core( if (node_id != (invalid_index & ~index_msb_1_mask) && !sample_filter(query_id + query_id_offset, to_source_index(node_id), - get_cagra_sample_filter_data(filter_payload))) { + filter_payload.sample_filter_data())) { result_distances_buffer[i] = utils::get_max_value(); result_indices_buffer[i] = invalid_index; } From 51716a6b375c00423a59dc01086ea7f21d6b7389 Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Thu, 4 Jun 2026 15:46:55 -0500 Subject: [PATCH 07/11] DOC Document CAGRA UDF metadata offset handling --- .../detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh index de96a7c7cc..343f3570dc 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh @@ -79,6 +79,8 @@ cagra_sample_filter extract_cagra_sample_filter(const SampleFilter return out; } +/// Host: find UDF compile/link metadata only. Query offsets stay in the runtime payload produced +/// by @ref extract_cagra_sample_filter and are applied before calling the linked sample_filter. 
template const ::cuvs::neighbors::filtering::udf_filter* get_cagra_udf_filter( const SampleFilterT& sample_filter) From 3f021bb9c9cbdc6c5e93880050a45a7668be201e Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Thu, 4 Jun 2026 17:58:22 -0500 Subject: [PATCH 08/11] ENH Cache CAGRA filter payloads behind owned device pointers --- .../jit_lto_kernels/cagra_filter_payload.cuh | 145 +++++++++++++++--- .../search_multi_cta_kernel_launcher_jit.cuh | 7 +- .../search_multi_kernel_launcher_jit.cuh | 12 +- .../search_single_cta_kernel_launcher_jit.cuh | 35 ++--- 4 files changed, 154 insertions(+), 45 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh index 343f3570dc..0e17084e39 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh @@ -7,29 +7,132 @@ #include "../../../sample_filter.cuh" // public filter types #include "../../sample_filter_data.cuh" +#if !defined(__CUDACC_RTC__) +#include + +#include +#endif + +#include #include #include +#if !defined(__CUDACC_RTC__) +#include +#include +#include +#include +#endif + namespace cuvs::neighbors::cagra::detail { template using cagra_filter_data_storage = ::cuvs::neighbors::detail::bitset_filter_data_t; -/// Host/device payload for linked CAGRA sample filters plus query offset for wrapped filters. +/// Device payload for linked CAGRA sample filters plus query offset for wrapped filters. 
template struct cagra_sample_filter { - cagra_filter_data_storage filter_data_storage{}; void* filter_data{nullptr}; std::uint32_t query_id_offset{0}; - __device__ __forceinline__ void* sample_filter_data() + __device__ __forceinline__ void* sample_filter_data() { return filter_data; } +}; + +#if !defined(__CUDACC_RTC__) + +template +std::uint64_t cagra_payload_hash(PayloadT const& payload) +{ + static_assert(std::is_trivially_copyable_v); + constexpr std::uint64_t kOffset = 1469598103934665603ull; + constexpr std::uint64_t kPrime = 1099511628211ull; + auto const* bytes = reinterpret_cast(&payload); + std::uint64_t hash = kOffset; + for (std::size_t i = 0; i < sizeof(PayloadT); ++i) { + hash ^= bytes[i]; + hash *= kPrime; + } + return hash; +} + +template +struct cagra_device_payload_owner { + struct state { + PayloadT host_payload{}; + PayloadT* device_payload{nullptr}; + cudaStream_t stream{}; + std::mutex mutex; + + explicit state(PayloadT payload) : host_payload(payload) {} + + ~state() noexcept + { + if (device_payload != nullptr) { + RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(device_payload, stream)); + } + } + + PayloadT* dev_ptr(cudaStream_t cuda_stream) + { + std::lock_guard lock(mutex); + if (device_payload == nullptr) { + RAFT_CUDA_TRY( + cudaMallocAsync(reinterpret_cast(&device_payload), sizeof(PayloadT), cuda_stream)); + RAFT_CUDA_TRY(cudaMemcpyAsync(device_payload, + &host_payload, + sizeof(PayloadT), + cudaMemcpyHostToDevice, + cuda_stream)); + stream = cuda_stream; + } + return device_payload; + } + }; + + cagra_device_payload_owner() = default; + + explicit cagra_device_payload_owner(PayloadT payload) { - if (filter_data != nullptr) { return filter_data; } - if (filter_data_storage.bitset_ptr != nullptr) { - // The payload is passed by value to kernels; take the embedded storage address on device. 
- return &filter_data_storage; + static std::mutex cache_mutex; + static std::unordered_map> cache; + + const auto key = cagra_payload_hash(payload); + std::lock_guard lock(cache_mutex); + if (auto it = cache.find(key); it != cache.end()) { + if (auto cached = it->second; + std::memcmp(&cached->host_payload, &payload, sizeof(PayloadT)) == 0) { + state_ = std::move(cached); + return; + } } - return nullptr; + state_ = std::make_shared(payload); + cache[key] = state_; + } + + void* dev_ptr(cudaStream_t stream) const + { + return state_ == nullptr ? nullptr : state_->dev_ptr(stream); + } + + PayloadT const* host_payload() const + { + return state_ == nullptr ? nullptr : &state_->host_payload; + } + + private: + std::shared_ptr state_; +}; + +template +struct cagra_sample_filter_payload { + cagra_sample_filter payload{}; + cagra_device_payload_owner> storage_owner{}; + + cagra_sample_filter device_payload(cudaStream_t stream) const + { + auto out = payload; + if (out.filter_data == nullptr) { out.filter_data = storage_owner.dev_ptr(stream); } + return out; } }; @@ -47,31 +150,33 @@ template <> struct is_udf_filter<::cuvs::neighbors::filtering::udf_filter> : std::true_type {}; template -void fill_cagra_sample_filter(cagra_sample_filter& out, const FilterT& filter) +void fill_cagra_sample_filter(cagra_sample_filter_payload& out, const FilterT& filter) { using DecayedFilter = std::decay_t; if constexpr (is_bitset_filter::value) { const auto bitset_view = filter.view(); - out.filter_data_storage = - cagra_filter_data_storage{const_cast(bitset_view.data()), - static_cast(bitset_view.size()), - static_cast( - bitset_view.get_original_nbits())}; + out.storage_owner = + cagra_device_payload_owner>{ + cagra_filter_data_storage{const_cast(bitset_view.data()), + static_cast(bitset_view.size()), + static_cast( + bitset_view.get_original_nbits())}}; } else if constexpr (is_udf_filter::value) { - out.filter_data = filter.filter_data; + out.payload.filter_data = filter.filter_data; } 
} -/// Host: fill @ref cagra_sample_filter from a CAGRA filter object (used by JIT LTO launchers). +/// Host: fill @ref cagra_sample_filter_payload from a CAGRA filter object. template -cagra_sample_filter extract_cagra_sample_filter(const SampleFilterT& sample_filter) +cagra_sample_filter_payload extract_cagra_sample_filter( + const SampleFilterT& sample_filter) { - cagra_sample_filter out; + cagra_sample_filter_payload out; if constexpr (requires { sample_filter.filter; sample_filter.offset; }) { - out.query_id_offset = sample_filter.offset; + out.payload.query_id_offset = sample_filter.offset; fill_cagra_sample_filter(out, sample_filter.filter); } else { fill_cagra_sample_filter(out, sample_filter); @@ -95,4 +200,6 @@ const ::cuvs::neighbors::filtering::udf_filter* get_cagra_udf_filter( } } +#endif // !defined(__CUDACC_RTC__) + } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh index 82bdd8ac93..e590015b8e 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh @@ -56,8 +56,9 @@ void select_and_run(const dataset_descriptor_host& dat SampleFilterT sample_filter, cudaStream_t stream) { - const auto bf = extract_cagra_sample_filter(sample_filter); - const uint32_t query_id_offset = bf.query_id_offset; + const auto filter_payload_owner = extract_cagra_sample_filter(sample_filter); + const auto filter_payload = filter_payload_owner.device_payload(stream); + const uint32_t query_id_offset = filter_payload.query_id_offset; std::shared_ptr launcher = make_cagra_multi_cta_jit_launcher& dat num_executed_iterations, static_cast(graph.extent(0)), query_id_offset, - bf); + filter_payload); }; cuvs::neighbors::detail::safely_launch_kernel_with_smem_size< multi_cta_search::search_multi_cta_kernel_func_t>( diff --git 
a/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh index f2fa0a327b..5f5d6432c6 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh @@ -111,7 +111,8 @@ void compute_distance_to_child_nodes_jit( cudaStream_t cuda_stream, std::shared_ptr const& launcher) { - const auto bf = extract_cagra_sample_filter(sample_filter); + const auto filter_payload_owner = extract_cagra_sample_filter(sample_filter); + const auto filter_payload = filter_payload_owner.device_payload(cuda_stream); const auto block_size = 128; const auto teams_per_block = block_size / dataset_desc.team_size; @@ -142,7 +143,7 @@ void compute_distance_to_child_nodes_jit( result_indices_ptr, result_distances_ptr, ldd, - bf); + filter_payload); RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -160,8 +161,9 @@ void apply_filter_jit(const SourceIndexT* source_indices_ptr, cudaStream_t cuda_stream, std::shared_ptr const& launcher) { - const auto bf = extract_cagra_sample_filter(sample_filter); - const auto effective_query_id_offset = query_id_offset + bf.query_id_offset; + const auto filter_payload_owner = extract_cagra_sample_filter(sample_filter); + const auto filter_payload = filter_payload_owner.device_payload(cuda_stream); + const auto effective_query_id_offset = query_id_offset + filter_payload.query_id_offset; const std::uint32_t block_size = 256; const std::uint32_t grid_size = raft::ceildiv(num_queries * result_buffer_size, block_size); @@ -181,7 +183,7 @@ void apply_filter_jit(const SourceIndexT* source_indices_ptr, result_buffer_size, num_queries, effective_query_id_offset, - bf); + filter_payload); RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh index 
595a705726..9ca48973e1 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh @@ -93,20 +93,18 @@ std::uint64_t cagra_sample_filter_type_id(const SampleFilterT& sample_filter) template std::uint64_t cagra_sample_filter_hash(const SampleFilterT& sample_filter) { - const auto payload = extract_cagra_sample_filter(sample_filter); + const auto payload_owner = extract_cagra_sample_filter(sample_filter); std::uint64_t seed = cagra_sample_filter_type_id(sample_filter); - seed = cagra_hash_combine( - seed, - static_cast( - reinterpret_cast(payload.filter_data_storage.bitset_ptr))); - seed = cagra_hash_combine(seed, - static_cast(payload.filter_data_storage.bitset_len)); - seed = cagra_hash_combine(seed, - static_cast( - payload.filter_data_storage.original_nbits)); + if (const auto* storage = payload_owner.storage_owner.host_payload(); storage != nullptr) { + seed = cagra_hash_combine( + seed, static_cast(reinterpret_cast(storage->bitset_ptr))); + seed = cagra_hash_combine(seed, static_cast(storage->bitset_len)); + seed = cagra_hash_combine(seed, static_cast(storage->original_nbits)); + } seed = cagra_hash_combine( - seed, static_cast(reinterpret_cast(payload.filter_data))); - seed = cagra_hash_combine(seed, static_cast(payload.query_id_offset)); + seed, + static_cast(reinterpret_cast(payload_owner.payload.filter_data))); + seed = cagra_hash_combine(seed, static_cast(payload_owner.payload.query_id_offset)); seed = cagra_hash_combine(seed, cagra_udf_source_hash(sample_filter)); return seed; } @@ -507,6 +505,7 @@ struct alignas(kCacheLineBytes) persistent_runner_jit_t : public persistent_runn rmm::device_uvector hashmap; std::atomic> last_touch; uint64_t param_hash; + cagra_sample_filter_payload filter_payload_owner; cagra_sample_filter filter_payload; static inline auto calculate_parameter_hash( @@ -625,9 +624,9 @@ struct alignas(kCacheLineBytes) 
persistent_runner_jit_t : public persistent_runn topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps)) { - const auto bf = extract_cagra_sample_filter(sample_filter); - this->filter_payload = bf; - const uint32_t query_id_offset = bf.query_id_offset; + filter_payload_owner = extract_cagra_sample_filter(sample_filter); + this->filter_payload = filter_payload_owner.device_payload(stream); + const uint32_t query_id_offset = filter_payload.query_id_offset; // set kernel launch parameters dim3 gs = calc_coop_grid_size(block_size, smem_size, persistent_device_usage); @@ -815,9 +814,9 @@ void select_and_run( const SourceIndexT* source_indices_ptr = source_indices.has_value() ? source_indices->data_handle() : nullptr; - const auto bf = extract_cagra_sample_filter(sample_filter); - const auto filter_payload = bf; - const uint32_t query_id_offset = bf.query_id_offset; + const auto filter_payload_owner = extract_cagra_sample_filter(sample_filter); + const auto filter_payload = filter_payload_owner.device_payload(stream); + const uint32_t query_id_offset = filter_payload.query_id_offset; // Use common logic to compute launch config auto config = compute_launch_config(num_itopk_candidates, ps.itopk_size, block_size); From 5407d06b6319ed29f0a0a76a689c8f67ea1614ab Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Thu, 4 Jun 2026 19:14:14 -0500 Subject: [PATCH 09/11] FIX Cache CAGRA filter payloads with stream-ordered device reuse --- .../jit_lto_kernels/cagra_filter_payload.cuh | 75 +++++++++++++++---- 1 file changed, 59 insertions(+), 16 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh index 0e17084e39..b3c749ec59 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh @@ -22,6 +22,7 @@ #include #include #include +#include #endif namespace 
cuvs::neighbors::cagra::detail { @@ -61,6 +62,8 @@ struct cagra_device_payload_owner { PayloadT host_payload{}; PayloadT* device_payload{nullptr}; cudaStream_t stream{}; + cudaEvent_t ready_event{}; + int device{-1}; std::mutex mutex; explicit state(PayloadT payload) : host_payload(payload) {} @@ -70,12 +73,14 @@ struct cagra_device_payload_owner { if (device_payload != nullptr) { RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(device_payload, stream)); } + if (ready_event != nullptr) { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(ready_event)); } } PayloadT* dev_ptr(cudaStream_t cuda_stream) { std::lock_guard lock(mutex); if (device_payload == nullptr) { + RAFT_CUDA_TRY(cudaGetDevice(&device)); RAFT_CUDA_TRY( cudaMallocAsync(reinterpret_cast(&device_payload), sizeof(PayloadT), cuda_stream)); RAFT_CUDA_TRY(cudaMemcpyAsync(device_payload, @@ -83,35 +88,73 @@ struct cagra_device_payload_owner { sizeof(PayloadT), cudaMemcpyHostToDevice, cuda_stream)); + RAFT_CUDA_TRY(cudaEventCreateWithFlags(&ready_event, cudaEventDisableTiming)); + RAFT_CUDA_TRY(cudaEventRecord(ready_event, cuda_stream)); stream = cuda_stream; + } else { + RAFT_CUDA_TRY(cudaStreamWaitEvent(cuda_stream, ready_event, 0)); } return device_payload; } }; + struct cache_key { + std::uint64_t payload_hash{}; + int device{}; + + bool operator==(cache_key const& other) const + { + return payload_hash == other.payload_hash && device == other.device; + } + }; + + struct cache_key_hash { + std::size_t operator()(cache_key const& key) const + { + auto seed = static_cast(key.payload_hash); + seed ^= static_cast(key.device) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + return seed; + } + }; + cagra_device_payload_owner() = default; explicit cagra_device_payload_owner(PayloadT payload) + : state_{std::make_shared(payload)} { - static std::mutex cache_mutex; - static std::unordered_map> cache; - - const auto key = cagra_payload_hash(payload); - std::lock_guard lock(cache_mutex); - if (auto it = cache.find(key); it != cache.end()) { - 
if (auto cached = it->second; - std::memcmp(&cached->host_payload, &payload, sizeof(PayloadT)) == 0) { - state_ = std::move(cached); - return; - } - } - state_ = std::make_shared(payload); - cache[key] = state_; } void* dev_ptr(cudaStream_t stream) const { - return state_ == nullptr ? nullptr : state_->dev_ptr(stream); + if (state_ == nullptr) { return nullptr; } + + int device{}; + RAFT_CUDA_TRY(cudaGetDevice(&device)); + + static std::mutex cache_mutex; + // Keep cached payload copies for process lifetime to avoid per-search allocation/copy churn. + // Cross-stream reuse is ordered by each state's ready_event before kernels consume the pointer. + static std::unordered_map>, cache_key_hash> cache; + + const auto key = cache_key{cagra_payload_hash(state_->host_payload), device}; + std::shared_ptr selected_state; + { + std::lock_guard lock(cache_mutex); + auto& entries = cache[key]; + for (auto const& cached : entries) { + if (std::memcmp(&cached->host_payload, &state_->host_payload, sizeof(PayloadT)) == 0) { + selected_state = cached; + break; + } + } + if (selected_state == nullptr) { + selected_state = state_; + entries.push_back(selected_state); + } + } + + state_ = std::move(selected_state); + return state_->dev_ptr(stream); } PayloadT const* host_payload() const @@ -120,7 +163,7 @@ struct cagra_device_payload_owner { } private: - std::shared_ptr state_; + mutable std::shared_ptr state_; }; template From ee1475a33527cb929b00617e27a5dc2eb3dc63b8 Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Fri, 5 Jun 2026 09:07:22 -0500 Subject: [PATCH 10/11] FIX Style fixes --- cpp/include/cuvs/neighbors/common.hpp | 4 +-- cpp/src/neighbors/cagra.cuh | 10 ++++---- .../jit_lto_kernels/apply_filter_kernel.cu.in | 6 ++--- .../jit_lto_kernels/cagra_filter_payload.cuh | 25 ++++++++----------- .../cagra_jit_launcher_factory.hpp | 4 +-- .../jit_lto_kernels/search_multi_jit.cuh | 14 +++++------ .../detail/cagra/search_multi_kernel.cuh | 3 +-- 
.../search_multi_kernel_launcher_jit.cuh | 4 +-- .../search_single_cta_kernel_launcher_jit.cuh | 11 ++++---- .../detail/cagra/shared_launcher_jit.hpp | 5 ++-- .../neighbors/ann_cagra/test_filter_udf.cu | 8 +++--- 11 files changed, 43 insertions(+), 51 deletions(-) diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 8d3b04810f..2fd804f115 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -644,8 +644,8 @@ struct udf_filter : public base_filter { udf_filter() = default; explicit udf_filter(std::string source, - void* filter_data = nullptr, - float filtering_rate = -1.0f, + void* filter_data = nullptr, + float filtering_rate = -1.0f, std::string function_name = "cuvs_filter_udf") : source(std::move(source)), filter_data(filter_data), diff --git a/cpp/src/neighbors/cagra.cuh b/cpp/src/neighbors/cagra.cuh index ffe33ecd48..ee87c2c0ab 100644 --- a/cpp/src/neighbors/cagra.cuh +++ b/cpp/src/neighbors/cagra.cuh @@ -392,11 +392,11 @@ void search(raft::resources const& res, if (params.filtering_rate < 0.0) { const float min_filtering_rate = 0.0f; const float max_filtering_rate = 0.999f; - params_copy.filtering_rate = sample_filter.filtering_rate < 0.0f - ? 0.0f - : std::min(std::max(sample_filter.filtering_rate, - min_filtering_rate), - max_filtering_rate); + params_copy.filtering_rate = + sample_filter.filtering_rate < 0.0f + ? 
0.0f + : std::min(std::max(sample_filter.filtering_rate, min_filtering_rate), + max_filtering_rate); } auto sample_filter_copy = sample_filter; return search_with_filtering( diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in index c6e4c49df8..5829072631 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in @@ -8,9 +8,9 @@ namespace { -using index_t = @index_type@; -using distance_t = @distance_type@; -using source_index_t = @source_index_type@; +using index_t = @index_type@; +using distance_t = @distance_type@; +using source_index_t = @source_index_type@; using cagra_sample_filter_t = cuvs::neighbors::cagra::detail::cagra_sample_filter; } // namespace diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh index b3c749ec59..725f779232 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh @@ -81,13 +81,10 @@ struct cagra_device_payload_owner { std::lock_guard lock(mutex); if (device_payload == nullptr) { RAFT_CUDA_TRY(cudaGetDevice(&device)); - RAFT_CUDA_TRY( - cudaMallocAsync(reinterpret_cast(&device_payload), sizeof(PayloadT), cuda_stream)); - RAFT_CUDA_TRY(cudaMemcpyAsync(device_payload, - &host_payload, - sizeof(PayloadT), - cudaMemcpyHostToDevice, - cuda_stream)); + RAFT_CUDA_TRY(cudaMallocAsync( + reinterpret_cast(&device_payload), sizeof(PayloadT), cuda_stream)); + RAFT_CUDA_TRY(cudaMemcpyAsync( + device_payload, &host_payload, sizeof(PayloadT), cudaMemcpyHostToDevice, cuda_stream)); RAFT_CUDA_TRY(cudaEventCreateWithFlags(&ready_event, cudaEventDisableTiming)); RAFT_CUDA_TRY(cudaEventRecord(ready_event, cuda_stream)); stream = 
cuda_stream; @@ -119,8 +116,7 @@ struct cagra_device_payload_owner { cagra_device_payload_owner() = default; - explicit cagra_device_payload_owner(PayloadT payload) - : state_{std::make_shared(payload)} + explicit cagra_device_payload_owner(PayloadT payload) : state_{std::make_shared(payload)} { } @@ -198,12 +194,11 @@ void fill_cagra_sample_filter(cagra_sample_filter_payload& out, co using DecayedFilter = std::decay_t; if constexpr (is_bitset_filter::value) { const auto bitset_view = filter.view(); - out.storage_owner = - cagra_device_payload_owner>{ - cagra_filter_data_storage{const_cast(bitset_view.data()), - static_cast(bitset_view.size()), - static_cast( - bitset_view.get_original_nbits())}}; + out.storage_owner = cagra_device_payload_owner>{ + cagra_filter_data_storage{ + const_cast(bitset_view.data()), + static_cast(bitset_view.size()), + static_cast(bitset_view.get_original_nbits())}}; } else if constexpr (is_udf_filter::value) { out.payload.filter_data = filter.filter_data; } diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp index 997f916d78..896817d869 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp @@ -7,10 +7,10 @@ #include "../compute_distance.hpp" #include "../shared_launcher_jit.hpp" +#include "sample_filter_udf.cuh" #include "search_multi_cta_planner.hpp" #include "search_multi_kernel_planner.hpp" #include "search_single_cta_planner.hpp" -#include "sample_filter_udf.cuh" #include #include @@ -337,7 +337,7 @@ std::shared_ptr make_cagra_multi_cta_jit_launcher( IndexT, DistanceT, SourceIndexT>( - dataset_desc, std::move(sample_filter_udf_fragment)); + dataset_desc, std::move(sample_filter_udf_fragment)); } /// Build a JIT AlgorithmLauncher for multi-kernel CAGRA helpers that need `setup_workspace` and diff 
--git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh index f2495c0ad6..a714ced5c2 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh @@ -113,8 +113,8 @@ __device__ void compute_distance_to_child_nodes_kernel_jit( // Get team_size_bits directly from base descriptor uint32_t team_size_bits = dataset_desc->team_size_bitshift(); - const auto team_size = 1u << team_size_bits; - const uint32_t ldb = hashmap::get_size(hash_bitlen); + const auto team_size = 1u << team_size_bits; + const uint32_t ldb = hashmap::get_size(hash_bitlen); const auto tid = threadIdx.x + blockDim.x * blockIdx.x; const auto global_team_id = tid >> team_size_bits; const auto local_query_id = blockIdx.y; @@ -133,8 +133,7 @@ __device__ void compute_distance_to_child_nodes_kernel_jit( if (parent_list_index == utils::get_max_value()) { return; } constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const auto raw_parent_index = - parent_candidates_ptr[parent_list_index + (lds * local_query_id)]; + const auto raw_parent_index = parent_candidates_ptr[parent_list_index + (lds * local_query_id)]; if (raw_parent_index == utils::get_max_value()) { result_distances_ptr[ldd * local_query_id + global_team_id] = @@ -167,10 +166,9 @@ __device__ void compute_distance_to_child_nodes_kernel_jit( } const SourceIndexT node_id = source_indices_ptr == nullptr - ? static_cast(parent_index) - : static_cast(source_indices_ptr[parent_index]); - if (!sample_filter( - filter_query_id, node_id, filter_payload.sample_filter_data())) { + ? 
static_cast(parent_index) + : static_cast(source_indices_ptr[parent_index]); + if (!sample_filter(filter_query_id, node_id, filter_payload.sample_filter_data())) { parent_candidates_ptr[parent_list_index + (lds * local_query_id)] = utils::get_max_value(); parent_distance_ptr[parent_list_index + (lds * local_query_id)] = diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh index 5016c4b962..e09ef82a39 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh @@ -617,8 +617,7 @@ struct search DISTANCE_T, SourceIndexT, sample_filter_jit_tag_t>( - dataset_desc, - make_cagra_sample_filter_udf_fragment(sample_filter)); + dataset_desc, make_cagra_sample_filter_udf_fragment(sample_filter)); apply_filter_jit( source_indices_ptr, diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh index 5f5d6432c6..a635a26b75 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh @@ -161,8 +161,8 @@ void apply_filter_jit(const SourceIndexT* source_indices_ptr, cudaStream_t cuda_stream, std::shared_ptr const& launcher) { - const auto filter_payload_owner = extract_cagra_sample_filter(sample_filter); - const auto filter_payload = filter_payload_owner.device_payload(cuda_stream); + const auto filter_payload_owner = extract_cagra_sample_filter(sample_filter); + const auto filter_payload = filter_payload_owner.device_payload(cuda_stream); const auto effective_query_id_offset = query_id_offset + filter_payload.query_id_offset; const std::uint32_t block_size = 256; diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh index 9ca48973e1..d995788a98 100644 --- 
a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh @@ -94,17 +94,18 @@ template std::uint64_t cagra_sample_filter_hash(const SampleFilterT& sample_filter) { const auto payload_owner = extract_cagra_sample_filter(sample_filter); - std::uint64_t seed = cagra_sample_filter_type_id(sample_filter); + std::uint64_t seed = cagra_sample_filter_type_id(sample_filter); if (const auto* storage = payload_owner.storage_owner.host_payload(); storage != nullptr) { seed = cagra_hash_combine( seed, static_cast(reinterpret_cast(storage->bitset_ptr))); seed = cagra_hash_combine(seed, static_cast(storage->bitset_len)); seed = cagra_hash_combine(seed, static_cast(storage->original_nbits)); } - seed = cagra_hash_combine( - seed, - static_cast(reinterpret_cast(payload_owner.payload.filter_data))); - seed = cagra_hash_combine(seed, static_cast(payload_owner.payload.query_id_offset)); + seed = cagra_hash_combine(seed, + static_cast( + reinterpret_cast(payload_owner.payload.filter_data))); + seed = + cagra_hash_combine(seed, static_cast(payload_owner.payload.query_id_offset)); seed = cagra_hash_combine(seed, cagra_udf_source_hash(sample_filter)); return seed; } diff --git a/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp index ecb744ae30..15f73d7b30 100644 --- a/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp +++ b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp @@ -8,7 +8,7 @@ #include #include -#include "../../sample_filter.cuh" // For none_sample_filter, bitset_filter +#include "../../sample_filter.cuh" // For none_sample_filter, bitset_filter #include "jit_lto_kernels/cagra_filter_payload.cuh" // sample-filter payload helpers #include @@ -106,7 +106,8 @@ struct sample_filter_jit_tag { using InnerFilter = decltype(std::declval().filter); if constexpr (is_bitset_filter>::value || std::is_same_v, bitset_filter> 
|| - std::is_same_v, bitset_filter>) { + std::is_same_v, + bitset_filter>) { return cuvs::neighbors::detail::tag_filter_bitset{}; } else if constexpr (is_udf_filter>::value) { return cuvs::neighbors::detail::tag_filter_udf{}; diff --git a/cpp/tests/neighbors/ann_cagra/test_filter_udf.cu b/cpp/tests/neighbors/ann_cagra/test_filter_udf.cu index c19e30ff9c..ba65df5d65 100644 --- a/cpp/tests/neighbors/ann_cagra/test_filter_udf.cu +++ b/cpp/tests/neighbors/ann_cagra/test_filter_udf.cu @@ -195,9 +195,8 @@ TEST_P(CagraUdfFilterTest, RejectAllReturnsNoValidNeighbors) TEST_P(CagraUdfFilterTest, HighFilteringRateReturnsOnlyValidNeighbors) { - cuvs::neighbors::filtering::udf_filter udf_filter(high_filtering_rate_udf_source(), - nullptr, - high_filtering_rate); + cuvs::neighbors::filtering::udf_filter udf_filter( + high_filtering_rate_udf_source(), nullptr, high_filtering_rate); auto result = search(udf_filter); for (auto source_id : result.neighbors) { @@ -219,8 +218,7 @@ TEST_P(CagraUdfFilterTest, RepeatedUdfSearchWithSameSourceMatches) TEST_P(CagraUdfFilterTest, InvalidSourceThrows) { - cuvs::neighbors::filtering::udf_filter udf_filter( - "this is not valid cuda source", nullptr, 0.0f); + cuvs::neighbors::filtering::udf_filter udf_filter("this is not valid cuda source", nullptr, 0.0f); EXPECT_THROW(search(udf_filter), std::exception); } From aa016513a2cc12ceb32b2d0531a50d97a7460a2d Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Fri, 5 Jun 2026 17:03:18 -0500 Subject: [PATCH 11/11] FIX Cache CAGRA filter payload device state outside the JIT kernel ABI --- .../detail/cagra/cagra_filter_payload.hpp | 256 ++++++++++++++++++ .../jit_lto_kernels/cagra_filter_payload.cuh | 223 --------------- .../jit_lto_kernels/sample_filter_udf.cuh | 2 +- .../search_multi_cta_kernel_launcher_jit.cuh | 5 +- .../search_multi_kernel_launcher_jit.cuh | 6 +- .../search_single_cta_kernel_launcher_jit.cuh | 29 +- .../detail/cagra/shared_launcher_jit.hpp | 6 +- 7 files changed, 275 
insertions(+), 252 deletions(-) create mode 100644 cpp/src/neighbors/detail/cagra/cagra_filter_payload.hpp diff --git a/cpp/src/neighbors/detail/cagra/cagra_filter_payload.hpp b/cpp/src/neighbors/detail/cagra/cagra_filter_payload.hpp new file mode 100644 index 0000000000..6a5ef6514c --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/cagra_filter_payload.hpp @@ -0,0 +1,256 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include "../../sample_filter.cuh" // public filter types +#include "../sample_filter_data.cuh" +#include "jit_lto_kernels/cagra_filter_payload.cuh" + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail { + +template +using cagra_filter_data_storage = ::cuvs::neighbors::detail::bitset_filter_data_t; + +template +std::uint64_t cagra_payload_hash(PayloadT const& payload) +{ + static_assert(std::is_trivially_copyable_v); + constexpr std::uint64_t kOffset = 1469598103934665603ull; + constexpr std::uint64_t kPrime = 1099511628211ull; + auto const* bytes = reinterpret_cast(&payload); + std::uint64_t hash = kOffset; + for (std::size_t i = 0; i < sizeof(PayloadT); ++i) { + hash ^= bytes[i]; + hash *= kPrime; + } + return hash; +} + +template +struct cagra_device_payload_owner { + struct state { + PayloadT host_payload{}; + PayloadT* device_payload{nullptr}; + cudaStream_t stream{}; + cudaEvent_t ready_event{}; + int device{-1}; + std::mutex mutex; + + explicit state(PayloadT payload) : host_payload(payload) {} + + ~state() noexcept + { + if (device_payload != nullptr) { + RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(device_payload, stream)); + } + if (ready_event != nullptr) { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(ready_event)); } + } + + PayloadT* dev_ptr(cudaStream_t cuda_stream) + { + std::lock_guard lock(mutex); + if (device_payload == 
nullptr) { + RAFT_CUDA_TRY(cudaGetDevice(&device)); + RAFT_CUDA_TRY(cudaMallocAsync( + reinterpret_cast(&device_payload), sizeof(PayloadT), cuda_stream)); + RAFT_CUDA_TRY(cudaMemcpyAsync( + device_payload, &host_payload, sizeof(PayloadT), cudaMemcpyHostToDevice, cuda_stream)); + RAFT_CUDA_TRY(cudaEventCreateWithFlags(&ready_event, cudaEventDisableTiming)); + RAFT_CUDA_TRY(cudaEventRecord(ready_event, cuda_stream)); + stream = cuda_stream; + } else { + RAFT_CUDA_TRY(cudaStreamWaitEvent(cuda_stream, ready_event, 0)); + } + return device_payload; + } + }; + + // PayloadT is copied to device by value. Pointer fields inside PayloadT are shallow-copied and + // must already point to device-addressable memory that remains valid while the cached payload is + // usable. + struct cache_key { + std::uint64_t payload_hash{}; + int device{}; + + bool operator==(cache_key const& other) const + { + return payload_hash == other.payload_hash && device == other.device; + } + }; + + struct cache_key_hash { + std::size_t operator()(cache_key const& key) const + { + auto seed = static_cast(key.payload_hash); + seed ^= static_cast(key.device) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + return seed; + } + }; + + cagra_device_payload_owner() = default; + + void* dev_ptr(PayloadT payload, cudaStream_t stream) const + { + int device{}; + RAFT_CUDA_TRY(cudaGetDevice(&device)); + + // Keep cached payload copies for process lifetime to avoid per-search allocation/copy churn. + // Cross-stream reuse is ordered by each state's ready_event before kernels consume the pointer. 
+ const auto key = cache_key{cagra_payload_hash(payload), device}; + std::shared_ptr selected_state; + { + std::lock_guard lock(cache_mutex_); + auto& entries = cache_[key]; + for (auto const& cached : entries) { + if (std::memcmp(&cached->host_payload, &payload, sizeof(PayloadT)) == 0) { + selected_state = cached; + break; + } + } + if (selected_state == nullptr) { + selected_state = std::make_shared(payload); + entries.push_back(selected_state); + } + } + + return selected_state->dev_ptr(stream); + } + + private: + mutable std::mutex cache_mutex_; + mutable std::unordered_map>, cache_key_hash> cache_; +}; + +template +struct is_bitset_filter : std::false_type {}; + +template +struct is_bitset_filter<::cuvs::neighbors::filtering::bitset_filter> + : std::true_type {}; + +template +struct is_udf_filter : std::false_type {}; + +template <> +struct is_udf_filter<::cuvs::neighbors::filtering::udf_filter> : std::true_type {}; + +template +cagra_filter_data_storage make_cagra_filter_data_storage(const FilterT& filter) +{ + const auto bitset_view = filter.view(); + return cagra_filter_data_storage{ + const_cast(bitset_view.data()), + static_cast(bitset_view.size()), + static_cast(bitset_view.get_original_nbits())}; +} + +template +void* get_cagra_device_payload(PayloadT payload, cudaStream_t stream) +{ + static cagra_device_payload_owner owner; + return owner.dev_ptr(payload, stream); +} + +template +void fill_cagra_sample_filter(cagra_sample_filter& out, + const FilterT& filter, + cudaStream_t stream) +{ + using DecayedFilter = std::decay_t; + if constexpr (is_bitset_filter::value) { + out.filter_data = get_cagra_device_payload(make_cagra_filter_data_storage(filter), + stream); + } else if constexpr (is_udf_filter::value) { + out.filter_data = filter.filter_data; + } +} + +template +std::uint64_t cagra_filter_payload_hash(const FilterT& filter) +{ + using DecayedFilter = std::decay_t; + if constexpr (is_bitset_filter::value) { + return 
cagra_payload_hash(make_cagra_filter_data_storage(filter)); + } else if constexpr (requires { filter.filter; }) { + return cagra_filter_payload_hash(filter.filter); + } else { + return 0; + } +} + +template +void* cagra_filter_data_ptr(const FilterT& filter) +{ + using DecayedFilter = std::decay_t; + if constexpr (is_udf_filter::value) { + return filter.filter_data; + } else if constexpr (requires { filter.filter; }) { + return cagra_filter_data_ptr(filter.filter); + } else { + return nullptr; + } +} + +template +std::uint32_t cagra_filter_query_id_offset(const SampleFilterT& sample_filter) +{ + if constexpr (requires { + sample_filter.filter; + sample_filter.offset; + }) { + return sample_filter.offset; + } else { + return 0; + } +} + +/// Host: fill @ref cagra_sample_filter from a CAGRA filter object. +template +cagra_sample_filter extract_cagra_sample_filter(const SampleFilterT& sample_filter, + cudaStream_t stream) +{ + cagra_sample_filter out; + if constexpr (requires { + sample_filter.filter; + sample_filter.offset; + }) { + out.query_id_offset = sample_filter.offset; + fill_cagra_sample_filter(out, sample_filter.filter, stream); + } else { + fill_cagra_sample_filter(out, sample_filter, stream); + } + return out; +} + +/// Host: find UDF compile/link metadata only. Query offsets stay in the runtime payload produced +/// by @ref extract_cagra_sample_filter and are applied before calling the linked sample_filter. 
+template +const ::cuvs::neighbors::filtering::udf_filter* get_cagra_udf_filter( + const SampleFilterT& sample_filter) +{ + using DecayedFilter = std::decay_t; + if constexpr (is_udf_filter::value) { + return &sample_filter; + } else if constexpr (requires { sample_filter.filter; }) { + return get_cagra_udf_filter(sample_filter.filter); + } else { + return nullptr; + } +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh index 725f779232..f4e24ad2dc 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh @@ -4,32 +4,10 @@ */ #pragma once -#include "../../../sample_filter.cuh" // public filter types -#include "../../sample_filter_data.cuh" - -#if !defined(__CUDACC_RTC__) -#include - -#include -#endif - -#include #include -#include - -#if !defined(__CUDACC_RTC__) -#include -#include -#include -#include -#include -#endif namespace cuvs::neighbors::cagra::detail { -template -using cagra_filter_data_storage = ::cuvs::neighbors::detail::bitset_filter_data_t; - /// Device payload for linked CAGRA sample filters plus query offset for wrapped filters. 
template struct cagra_sample_filter { @@ -39,205 +17,4 @@ struct cagra_sample_filter { __device__ __forceinline__ void* sample_filter_data() { return filter_data; } }; -#if !defined(__CUDACC_RTC__) - -template -std::uint64_t cagra_payload_hash(PayloadT const& payload) -{ - static_assert(std::is_trivially_copyable_v); - constexpr std::uint64_t kOffset = 1469598103934665603ull; - constexpr std::uint64_t kPrime = 1099511628211ull; - auto const* bytes = reinterpret_cast(&payload); - std::uint64_t hash = kOffset; - for (std::size_t i = 0; i < sizeof(PayloadT); ++i) { - hash ^= bytes[i]; - hash *= kPrime; - } - return hash; -} - -template -struct cagra_device_payload_owner { - struct state { - PayloadT host_payload{}; - PayloadT* device_payload{nullptr}; - cudaStream_t stream{}; - cudaEvent_t ready_event{}; - int device{-1}; - std::mutex mutex; - - explicit state(PayloadT payload) : host_payload(payload) {} - - ~state() noexcept - { - if (device_payload != nullptr) { - RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(device_payload, stream)); - } - if (ready_event != nullptr) { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(ready_event)); } - } - - PayloadT* dev_ptr(cudaStream_t cuda_stream) - { - std::lock_guard lock(mutex); - if (device_payload == nullptr) { - RAFT_CUDA_TRY(cudaGetDevice(&device)); - RAFT_CUDA_TRY(cudaMallocAsync( - reinterpret_cast(&device_payload), sizeof(PayloadT), cuda_stream)); - RAFT_CUDA_TRY(cudaMemcpyAsync( - device_payload, &host_payload, sizeof(PayloadT), cudaMemcpyHostToDevice, cuda_stream)); - RAFT_CUDA_TRY(cudaEventCreateWithFlags(&ready_event, cudaEventDisableTiming)); - RAFT_CUDA_TRY(cudaEventRecord(ready_event, cuda_stream)); - stream = cuda_stream; - } else { - RAFT_CUDA_TRY(cudaStreamWaitEvent(cuda_stream, ready_event, 0)); - } - return device_payload; - } - }; - - struct cache_key { - std::uint64_t payload_hash{}; - int device{}; - - bool operator==(cache_key const& other) const - { - return payload_hash == other.payload_hash && device == other.device; 
- } - }; - - struct cache_key_hash { - std::size_t operator()(cache_key const& key) const - { - auto seed = static_cast(key.payload_hash); - seed ^= static_cast(key.device) + 0x9e3779b9 + (seed << 6) + (seed >> 2); - return seed; - } - }; - - cagra_device_payload_owner() = default; - - explicit cagra_device_payload_owner(PayloadT payload) : state_{std::make_shared(payload)} - { - } - - void* dev_ptr(cudaStream_t stream) const - { - if (state_ == nullptr) { return nullptr; } - - int device{}; - RAFT_CUDA_TRY(cudaGetDevice(&device)); - - static std::mutex cache_mutex; - // Keep cached payload copies for process lifetime to avoid per-search allocation/copy churn. - // Cross-stream reuse is ordered by each state's ready_event before kernels consume the pointer. - static std::unordered_map>, cache_key_hash> cache; - - const auto key = cache_key{cagra_payload_hash(state_->host_payload), device}; - std::shared_ptr selected_state; - { - std::lock_guard lock(cache_mutex); - auto& entries = cache[key]; - for (auto const& cached : entries) { - if (std::memcmp(&cached->host_payload, &state_->host_payload, sizeof(PayloadT)) == 0) { - selected_state = cached; - break; - } - } - if (selected_state == nullptr) { - selected_state = state_; - entries.push_back(selected_state); - } - } - - state_ = std::move(selected_state); - return state_->dev_ptr(stream); - } - - PayloadT const* host_payload() const - { - return state_ == nullptr ? 
nullptr : &state_->host_payload; - } - - private: - mutable std::shared_ptr state_; -}; - -template -struct cagra_sample_filter_payload { - cagra_sample_filter payload{}; - cagra_device_payload_owner> storage_owner{}; - - cagra_sample_filter device_payload(cudaStream_t stream) const - { - auto out = payload; - if (out.filter_data == nullptr) { out.filter_data = storage_owner.dev_ptr(stream); } - return out; - } -}; - -template -struct is_bitset_filter : std::false_type {}; - -template -struct is_bitset_filter<::cuvs::neighbors::filtering::bitset_filter> - : std::true_type {}; - -template -struct is_udf_filter : std::false_type {}; - -template <> -struct is_udf_filter<::cuvs::neighbors::filtering::udf_filter> : std::true_type {}; - -template -void fill_cagra_sample_filter(cagra_sample_filter_payload& out, const FilterT& filter) -{ - using DecayedFilter = std::decay_t; - if constexpr (is_bitset_filter::value) { - const auto bitset_view = filter.view(); - out.storage_owner = cagra_device_payload_owner>{ - cagra_filter_data_storage{ - const_cast(bitset_view.data()), - static_cast(bitset_view.size()), - static_cast(bitset_view.get_original_nbits())}}; - } else if constexpr (is_udf_filter::value) { - out.payload.filter_data = filter.filter_data; - } -} - -/// Host: fill @ref cagra_sample_filter_payload from a CAGRA filter object. -template -cagra_sample_filter_payload extract_cagra_sample_filter( - const SampleFilterT& sample_filter) -{ - cagra_sample_filter_payload out; - if constexpr (requires { - sample_filter.filter; - sample_filter.offset; - }) { - out.payload.query_id_offset = sample_filter.offset; - fill_cagra_sample_filter(out, sample_filter.filter); - } else { - fill_cagra_sample_filter(out, sample_filter); - } - return out; -} - -/// Host: find UDF compile/link metadata only. Query offsets stay in the runtime payload produced -/// by @ref extract_cagra_sample_filter and are applied before calling the linked sample_filter. 
-template -const ::cuvs::neighbors::filtering::udf_filter* get_cagra_udf_filter( - const SampleFilterT& sample_filter) -{ - using DecayedFilter = std::decay_t; - if constexpr (is_udf_filter::value) { - return &sample_filter; - } else if constexpr (requires { sample_filter.filter; }) { - return get_cagra_udf_filter(sample_filter.filter); - } else { - return nullptr; - } -} - -#endif // !defined(__CUDACC_RTC__) - } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh index b1381c30f1..3e16cf2bcc 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh @@ -5,7 +5,7 @@ #pragma once -#include "cagra_filter_payload.cuh" +#include "../cagra_filter_payload.hpp" #include #include diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh index e590015b8e..8a673405b7 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh @@ -56,9 +56,8 @@ void select_and_run(const dataset_descriptor_host& dat SampleFilterT sample_filter, cudaStream_t stream) { - const auto filter_payload_owner = extract_cagra_sample_filter(sample_filter); - const auto filter_payload = filter_payload_owner.device_payload(stream); - const uint32_t query_id_offset = filter_payload.query_id_offset; + const auto filter_payload = extract_cagra_sample_filter(sample_filter, stream); + const uint32_t query_id_offset = filter_payload.query_id_offset; std::shared_ptr launcher = make_cagra_multi_cta_jit_launcher const& launcher) { - const auto filter_payload_owner = extract_cagra_sample_filter(sample_filter); - const auto filter_payload = 
filter_payload_owner.device_payload(cuda_stream); + const auto filter_payload = extract_cagra_sample_filter(sample_filter, cuda_stream); const auto block_size = 128; const auto teams_per_block = block_size / dataset_desc.team_size; @@ -161,8 +160,7 @@ void apply_filter_jit(const SourceIndexT* source_indices_ptr, cudaStream_t cuda_stream, std::shared_ptr const& launcher) { - const auto filter_payload_owner = extract_cagra_sample_filter(sample_filter); - const auto filter_payload = filter_payload_owner.device_payload(cuda_stream); + const auto filter_payload = extract_cagra_sample_filter(sample_filter, cuda_stream); const auto effective_query_id_offset = query_id_offset + filter_payload.query_id_offset; const std::uint32_t block_size = 256; diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh index d995788a98..aa4a78c513 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh @@ -93,19 +93,15 @@ std::uint64_t cagra_sample_filter_type_id(const SampleFilterT& sample_filter) template std::uint64_t cagra_sample_filter_hash(const SampleFilterT& sample_filter) { - const auto payload_owner = extract_cagra_sample_filter(sample_filter); - std::uint64_t seed = cagra_sample_filter_type_id(sample_filter); - if (const auto* storage = payload_owner.storage_owner.host_payload(); storage != nullptr) { - seed = cagra_hash_combine( - seed, static_cast(reinterpret_cast(storage->bitset_ptr))); - seed = cagra_hash_combine(seed, static_cast(storage->bitset_len)); - seed = cagra_hash_combine(seed, static_cast(storage->original_nbits)); - } - seed = cagra_hash_combine(seed, - static_cast( - reinterpret_cast(payload_owner.payload.filter_data))); - seed = - cagra_hash_combine(seed, static_cast(payload_owner.payload.query_id_offset)); + std::uint64_t seed = 
cagra_sample_filter_type_id(sample_filter); + seed = cagra_hash_combine( + seed, cagra_filter_payload_hash(sample_filter)); + seed = cagra_hash_combine( + seed, + static_cast(reinterpret_cast( + cagra_filter_data_ptr(sample_filter)))); + seed = cagra_hash_combine( + seed, static_cast(cagra_filter_query_id_offset(sample_filter))); seed = cagra_hash_combine(seed, cagra_udf_source_hash(sample_filter)); return seed; } @@ -506,7 +502,6 @@ struct alignas(kCacheLineBytes) persistent_runner_jit_t : public persistent_runn rmm::device_uvector hashmap; std::atomic> last_touch; uint64_t param_hash; - cagra_sample_filter_payload filter_payload_owner; cagra_sample_filter filter_payload; static inline auto calculate_parameter_hash( @@ -625,8 +620,7 @@ struct alignas(kCacheLineBytes) persistent_runner_jit_t : public persistent_runn topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps)) { - filter_payload_owner = extract_cagra_sample_filter(sample_filter); - this->filter_payload = filter_payload_owner.device_payload(stream); + this->filter_payload = extract_cagra_sample_filter(sample_filter, stream); const uint32_t query_id_offset = filter_payload.query_id_offset; // set kernel launch parameters @@ -815,8 +809,7 @@ void select_and_run( const SourceIndexT* source_indices_ptr = source_indices.has_value() ? 
source_indices->data_handle() : nullptr; - const auto filter_payload_owner = extract_cagra_sample_filter(sample_filter); - const auto filter_payload = filter_payload_owner.device_payload(stream); + const auto filter_payload = extract_cagra_sample_filter(sample_filter, stream); const uint32_t query_id_offset = filter_payload.query_id_offset; // Use common logic to compute launch config diff --git a/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp index 15f73d7b30..e5157ffa6a 100644 --- a/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp +++ b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp @@ -8,8 +8,8 @@ #include #include -#include "../../sample_filter.cuh" // For none_sample_filter, bitset_filter -#include "jit_lto_kernels/cagra_filter_payload.cuh" // sample-filter payload helpers +#include "../../sample_filter.cuh" // For none_sample_filter, bitset_filter +#include "cagra_filter_payload.hpp" // sample-filter payload helpers #include #include @@ -116,7 +116,7 @@ struct sample_filter_jit_tag { cagra_jit_sample_filter_tag_type_always_false, "CAGRA JIT: sample_filter_jit_tag does not know how to link this filter. " "CagraSampleFilterWithQueryIdOffset requires Inner to be a supported " - "built-in filter or udf_filter (see cagra_filter_payload.cuh and " + "built-in filter or udf_filter (see cagra_filter_payload.hpp and " "sample_filter_utils.cuh). " "For a new filter kind, add a sample_filter_jit_tag branch. " "(SAMPLE_FILTER_T in error; check InnerFilter in compiler output.)");