diff --git a/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp index cb33e4109b..a82b6bcc1e 100644 --- a/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp +++ b/cpp/include/cuvs/detail/jit_lto/common_fragments.hpp @@ -13,6 +13,7 @@ struct tag_i8 {}; struct tag_u8 {}; struct tag_filter_none {}; struct tag_filter_bitset {}; +struct tag_filter_udf {}; struct tag_bitset_u32 {}; diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index 8edbcab8fa..471807c0c7 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -346,7 +346,10 @@ struct search_params : cuvs::neighbors::search_params { /** * A parameter indicating the rate of nodes to be filtered-out, when filtering is used. * The value must be equal to or greater than 0.0 and less than 1.0. Default value is - * negative, in which case the filtering rate is automatically calculated. + * negative, in which case the filtering rate is automatically calculated when possible. + * For `filtering::udf_filter`, CAGRA uses `udf_filter::filtering_rate` when this value is + * negative. If both values are negative, CAGRA assumes 0.0 because a UDF's selectivity cannot be + * inferred from the source string. 
*/ float filtering_rate = -1.0; }; diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 887593c23b..2fd804f115 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -24,7 +24,9 @@ #include #include +#include #include +#include #ifdef __cpp_lib_bitops #include @@ -495,7 +497,7 @@ namespace filtering { * @{ */ -enum class FilterType { None, Bitmap, Bitset }; +enum class FilterType { None, Bitmap, Bitset, UDF }; struct base_filter { ~base_filter() = default; @@ -615,6 +617,46 @@ struct bitset_filter : public base_filter { void to_csr(raft::resources const& handle, csr_matrix_t& csr); }; +/** + * @brief JIT-LTO user-defined filter predicate. + * + * The source must define a device function named by @c function_name with signature: + * + * @code{.cpp} + * __device__ bool cuvs_filter_udf(uint32_t query_id, source_index_t source_id, void* filter_data); + * @endcode + * + * Return @c true to allow a source vector to appear in the results and @c false to reject it. + * @c filter_data is passed through unchanged and must point to device-accessible memory when the + * UDF dereferences it. CAGRA currently provides @c source_index_t as @c uint32_t in the generated + * JIT fragment. + */ +struct udf_filter : public base_filter { + /** CUDA C++ source containing the device predicate. */ + std::string source; + /** Opaque device-accessible pointer passed to the predicate. */ + void* filter_data = nullptr; + /** Estimated fraction of rows rejected by the predicate, or negative if unknown. */ + float filtering_rate = -1.0f; + /** Device function name to call from the generated CAGRA sample filter. 
*/ + std::string function_name = "cuvs_filter_udf"; + + udf_filter() = default; + + explicit udf_filter(std::string source, + void* filter_data = nullptr, + float filtering_rate = -1.0f, + std::string function_name = "cuvs_filter_udf") + : source(std::move(source)), + filter_data(filter_data), + filtering_rate(filtering_rate), + function_name(std::move(function_name)) + { + } + + FilterType get_filter_type() const override { return FilterType::UDF; } +}; + /** @} */ // end group neighbors_filtering /** diff --git a/cpp/src/neighbors/cagra.cuh b/cpp/src/neighbors/cagra.cuh index 73c3794d39..ee87c2c0ab 100644 --- a/cpp/src/neighbors/cagra.cuh +++ b/cpp/src/neighbors/cagra.cuh @@ -26,6 +26,8 @@ #include +#include + namespace cuvs::neighbors::cagra { // Member function implementations for cagra::index @@ -380,6 +382,25 @@ void search(raft::resources const& res, auto sample_filter_copy = sample_filter; return search_with_filtering( res, params_copy, idx, queries, neighbors, distances, sample_filter_copy); + } catch (const std::bad_cast&) { + } + + try { + auto& sample_filter = + dynamic_cast(sample_filter_ref); + search_params params_copy = params; + if (params.filtering_rate < 0.0) { + const float min_filtering_rate = 0.0f; + const float max_filtering_rate = 0.999f; + params_copy.filtering_rate = + sample_filter.filtering_rate < 0.0f + ? 
0.0f + : std::min(std::max(sample_filter.filtering_rate, min_filtering_rate), + max_filtering_rate); + } + auto sample_filter_copy = sample_filter; + return search_with_filtering( + res, params_copy, idx, queries, neighbors, distances, sample_filter_copy); } catch (const std::bad_cast&) { RAFT_FAIL("Unsupported sample filter type"); } diff --git a/cpp/src/neighbors/detail/cagra/cagra_filter_payload.hpp b/cpp/src/neighbors/detail/cagra/cagra_filter_payload.hpp new file mode 100644 index 0000000000..6a5ef6514c --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/cagra_filter_payload.hpp @@ -0,0 +1,256 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include "../../sample_filter.cuh" // public filter types +#include "../sample_filter_data.cuh" +#include "jit_lto_kernels/cagra_filter_payload.cuh" + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail { + +template +using cagra_filter_data_storage = ::cuvs::neighbors::detail::bitset_filter_data_t; + +template +std::uint64_t cagra_payload_hash(PayloadT const& payload) +{ + static_assert(std::is_trivially_copyable_v); + constexpr std::uint64_t kOffset = 1469598103934665603ull; + constexpr std::uint64_t kPrime = 1099511628211ull; + auto const* bytes = reinterpret_cast(&payload); + std::uint64_t hash = kOffset; + for (std::size_t i = 0; i < sizeof(PayloadT); ++i) { + hash ^= bytes[i]; + hash *= kPrime; + } + return hash; +} + +template +struct cagra_device_payload_owner { + struct state { + PayloadT host_payload{}; + PayloadT* device_payload{nullptr}; + cudaStream_t stream{}; + cudaEvent_t ready_event{}; + int device{-1}; + std::mutex mutex; + + explicit state(PayloadT payload) : host_payload(payload) {} + + ~state() noexcept + { + if (device_payload != nullptr) { + 
RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(device_payload, stream)); + } + if (ready_event != nullptr) { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(ready_event)); } + } + + PayloadT* dev_ptr(cudaStream_t cuda_stream) + { + std::lock_guard lock(mutex); + if (device_payload == nullptr) { + RAFT_CUDA_TRY(cudaGetDevice(&device)); + RAFT_CUDA_TRY(cudaMallocAsync( + reinterpret_cast(&device_payload), sizeof(PayloadT), cuda_stream)); + RAFT_CUDA_TRY(cudaMemcpyAsync( + device_payload, &host_payload, sizeof(PayloadT), cudaMemcpyHostToDevice, cuda_stream)); + RAFT_CUDA_TRY(cudaEventCreateWithFlags(&ready_event, cudaEventDisableTiming)); + RAFT_CUDA_TRY(cudaEventRecord(ready_event, cuda_stream)); + stream = cuda_stream; + } else { + RAFT_CUDA_TRY(cudaStreamWaitEvent(cuda_stream, ready_event, 0)); + } + return device_payload; + } + }; + + // PayloadT is copied to device by value. Pointer fields inside PayloadT are shallow-copied and + // must already point to device-addressable memory that remains valid while the cached payload is + // usable. + struct cache_key { + std::uint64_t payload_hash{}; + int device{}; + + bool operator==(cache_key const& other) const + { + return payload_hash == other.payload_hash && device == other.device; + } + }; + + struct cache_key_hash { + std::size_t operator()(cache_key const& key) const + { + auto seed = static_cast(key.payload_hash); + seed ^= static_cast(key.device) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + return seed; + } + }; + + cagra_device_payload_owner() = default; + + void* dev_ptr(PayloadT payload, cudaStream_t stream) const + { + int device{}; + RAFT_CUDA_TRY(cudaGetDevice(&device)); + + // Keep cached payload copies for process lifetime to avoid per-search allocation/copy churn. + // Cross-stream reuse is ordered by each state's ready_event before kernels consume the pointer. 
+ const auto key = cache_key{cagra_payload_hash(payload), device}; + std::shared_ptr selected_state; + { + std::lock_guard lock(cache_mutex_); + auto& entries = cache_[key]; + for (auto const& cached : entries) { + if (std::memcmp(&cached->host_payload, &payload, sizeof(PayloadT)) == 0) { + selected_state = cached; + break; + } + } + if (selected_state == nullptr) { + selected_state = std::make_shared(payload); + entries.push_back(selected_state); + } + } + + return selected_state->dev_ptr(stream); + } + + private: + mutable std::mutex cache_mutex_; + mutable std::unordered_map>, cache_key_hash> cache_; +}; + +template +struct is_bitset_filter : std::false_type {}; + +template +struct is_bitset_filter<::cuvs::neighbors::filtering::bitset_filter> + : std::true_type {}; + +template +struct is_udf_filter : std::false_type {}; + +template <> +struct is_udf_filter<::cuvs::neighbors::filtering::udf_filter> : std::true_type {}; + +template +cagra_filter_data_storage make_cagra_filter_data_storage(const FilterT& filter) +{ + const auto bitset_view = filter.view(); + return cagra_filter_data_storage{ + const_cast(bitset_view.data()), + static_cast(bitset_view.size()), + static_cast(bitset_view.get_original_nbits())}; +} + +template +void* get_cagra_device_payload(PayloadT payload, cudaStream_t stream) +{ + static cagra_device_payload_owner owner; + return owner.dev_ptr(payload, stream); +} + +template +void fill_cagra_sample_filter(cagra_sample_filter& out, + const FilterT& filter, + cudaStream_t stream) +{ + using DecayedFilter = std::decay_t; + if constexpr (is_bitset_filter::value) { + out.filter_data = get_cagra_device_payload(make_cagra_filter_data_storage(filter), + stream); + } else if constexpr (is_udf_filter::value) { + out.filter_data = filter.filter_data; + } +} + +template +std::uint64_t cagra_filter_payload_hash(const FilterT& filter) +{ + using DecayedFilter = std::decay_t; + if constexpr (is_bitset_filter::value) { + return 
cagra_payload_hash(make_cagra_filter_data_storage(filter)); + } else if constexpr (requires { filter.filter; }) { + return cagra_filter_payload_hash(filter.filter); + } else { + return 0; + } +} + +template +void* cagra_filter_data_ptr(const FilterT& filter) +{ + using DecayedFilter = std::decay_t; + if constexpr (is_udf_filter::value) { + return filter.filter_data; + } else if constexpr (requires { filter.filter; }) { + return cagra_filter_data_ptr(filter.filter); + } else { + return nullptr; + } +} + +template +std::uint32_t cagra_filter_query_id_offset(const SampleFilterT& sample_filter) +{ + if constexpr (requires { + sample_filter.filter; + sample_filter.offset; + }) { + return sample_filter.offset; + } else { + return 0; + } +} + +/// Host: fill @ref cagra_sample_filter from a CAGRA filter object. +template +cagra_sample_filter extract_cagra_sample_filter(const SampleFilterT& sample_filter, + cudaStream_t stream) +{ + cagra_sample_filter out; + if constexpr (requires { + sample_filter.filter; + sample_filter.offset; + }) { + out.query_id_offset = sample_filter.offset; + fill_cagra_sample_filter(out, sample_filter.filter, stream); + } else { + fill_cagra_sample_filter(out, sample_filter, stream); + } + return out; +} + +/// Host: find UDF compile/link metadata only. Query offsets stay in the runtime payload produced +/// by @ref extract_cagra_sample_filter and are applied before calling the linked sample_filter. 
+template +const ::cuvs::neighbors::filtering::udf_filter* get_cagra_udf_filter( + const SampleFilterT& sample_filter) +{ + using DecayedFilter = std::decay_t; + if constexpr (is_udf_filter::value) { + return &sample_filter; + } else if constexpr (requires { sample_filter.filter; }) { + return get_cagra_udf_filter(sample_filter.filter); + } else { + return nullptr; + } +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in index 49d5d2fa07..5829072631 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in @@ -8,10 +8,10 @@ namespace { -using index_t = @index_type@; -using distance_t = @distance_type@; -using source_index_t = @source_index_type@; -using cagra_bitset_t = cuvs::neighbors::cagra::detail::cagra_bitset; +using index_t = @index_type@; +using distance_t = @distance_type@; +using source_index_t = @source_index_type@; +using cagra_sample_filter_t = cuvs::neighbors::cagra::detail::cagra_sample_filter; } // namespace @@ -24,7 +24,7 @@ extern "C" __global__ void apply_filter_kernel(const source_index_t* const sourc const std::uint32_t result_buffer_size, const std::uint32_t num_queries, const std::uint32_t query_id_offset, - cagra_bitset_t bitset) + cagra_sample_filter_t filter_payload) { apply_filter_kernel_jit(source_indices_ptr, result_indices_ptr, @@ -33,7 +33,7 @@ extern "C" __global__ void apply_filter_kernel(const source_index_t* const sourc result_buffer_size, num_queries, query_id_offset, - bitset); + filter_payload); } static_assert(std::is_same_v -#include - -namespace cuvs::neighbors::cagra::detail { - -template -using cagra_bitset = cuvs::neighbors::detail::bitset_filter_data_t; - -/// Host: bitset payload for kernels plus query offset for wrapped CAGRA filters. 
-template -struct cagra_sample_filter { - cagra_bitset bitset{}; - std::uint32_t query_id_offset{0}; -}; - -template -struct is_bitset_filter : std::false_type {}; - -template -struct is_bitset_filter> - : std::true_type {}; - -/// Host: fill @ref cagra_sample_filter from a CAGRA filter object (used by JIT LTO launchers). -template -cagra_sample_filter extract_cagra_sample_filter(const SampleFilterT& sample_filter) -{ - cagra_sample_filter out; - if constexpr (requires { - sample_filter.filter; - sample_filter.offset; - }) { - out.query_id_offset = sample_filter.offset; - using InnerFilter = decltype(sample_filter.filter); - if constexpr (is_bitset_filter::value) { - const auto bitset_view = sample_filter.filter.view(); - out.bitset.bitset_ptr = const_cast(bitset_view.data()); - out.bitset.bitset_len = static_cast(bitset_view.size()); - out.bitset.original_nbits = static_cast(bitset_view.get_original_nbits()); - } - } - return out; -} - -} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh new file mode 100644 index 0000000000..f4e24ad2dc --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh @@ -0,0 +1,20 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include + +namespace cuvs::neighbors::cagra::detail { + +/// Device payload for linked CAGRA sample filters plus query offset for wrapped filters. 
+template +struct cagra_sample_filter { + void* filter_data{nullptr}; + std::uint32_t query_id_offset{0}; + + __device__ __forceinline__ void* sample_filter_data() { return filter_data; } +}; + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp index 60d17c5128..896817d869 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp @@ -7,6 +7,7 @@ #include "../compute_distance.hpp" #include "../shared_launcher_jit.hpp" +#include "sample_filter_udf.cuh" #include "search_multi_cta_planner.hpp" #include "search_multi_kernel_planner.hpp" #include "search_single_cta_planner.hpp" @@ -37,7 +38,8 @@ std::shared_ptr build_single_cta_launcher( const dataset_descriptor_host& dataset_desc, bool topk_by_bitonic_sort, bool bitonic_sort_and_merge_multi_warps, - bool persistent) + bool persistent, + std::unique_ptr sample_filter_udf_fragment) { single_cta_search::CagraSingleCtaSearchPlanner build_single_cta_launcher( } planner.add_search_kernel_fragment( topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps, persistent); - planner.add_sample_filter_device_function(); + planner.add_sample_filter_device_function(std::move(sample_filter_udf_fragment)); return planner.get_launcher(); } @@ -85,7 +87,8 @@ template std::shared_ptr build_multi_cta_launcher( - const dataset_descriptor_host& dataset_desc) + const dataset_descriptor_host& dataset_desc, + std::unique_ptr sample_filter_udf_fragment) { multi_cta_search::CagraMultiCtaSearchPlanner build_multi_cta_launcher( dataset_desc.metric, dataset_desc.team_size, dataset_desc.dataset_block_dim); } planner.add_search_multi_cta_kernel_fragment(); - planner.add_sample_filter_device_function(); + 
planner.add_sample_filter_device_function(std::move(sample_filter_udf_fragment)); return planner.get_launcher(); } @@ -130,7 +133,8 @@ template std::shared_ptr build_multi_kernel_launcher( const dataset_descriptor_host& dataset_desc, - const char* linked_kernel_name) + const char* linked_kernel_name, + std::unique_ptr sample_filter_udf_fragment) { multi_kernel_search::CagraMultiKernelSearchPlanner build_multi_kernel_launcher( planner.add_compute_distance_device_function( dataset_desc.metric, dataset_desc.team_size, dataset_desc.dataset_block_dim); } - planner.add_sample_filter_device_function(); + planner.add_sample_filter_device_function(std::move(sample_filter_udf_fragment)); planner.add_linked_kernel(linked_kernel_name); return planner.get_launcher(); } @@ -175,7 +179,8 @@ template -std::shared_ptr build_apply_filter_only_launcher() +std::shared_ptr build_apply_filter_only_launcher( + std::unique_ptr sample_filter_udf_fragment) { multi_kernel_search::CagraMultiKernelSearchPlanner build_apply_filter_only_launcher() CodebookTag, SampleFilterJitTag> planner("apply_filter_kernel"); - planner.add_sample_filter_device_function(); + planner.add_sample_filter_device_function(std::move(sample_filter_udf_fragment)); planner.add_linked_kernel("apply_filter_kernel"); return planner.get_launcher(); } @@ -204,7 +209,8 @@ std::shared_ptr make_cagra_single_cta_jit_launcher( const dataset_descriptor_host& dataset_desc, bool topk_by_bitonic_sort, bool bitonic_sort_and_merge_multi_warps, - bool persistent) + bool persistent, + std::unique_ptr sample_filter_udf_fragment = nullptr) { using DataTag = decltype(get_data_type_tag()); using IndexTag = decltype(get_index_type_tag()); @@ -225,7 +231,11 @@ std::shared_ptr make_cagra_single_cta_jit_launcher( IndexT, DistanceT, SourceIndexT>( - dataset_desc, topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps, persistent); + dataset_desc, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + persistent, + 
std::move(sample_filter_udf_fragment)); } using CodebookTag = codebook_tag_standard_t; if (dataset_desc.metric == cuvs::distance::DistanceType::BitwiseHamming) { @@ -242,7 +252,11 @@ std::shared_ptr make_cagra_single_cta_jit_launcher( IndexT, DistanceT, SourceIndexT>( - dataset_desc, topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps, persistent); + dataset_desc, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + persistent, + std::move(sample_filter_udf_fragment)); } using QueryTag = query_type_tag_standard_t; return cagra_jit_launcher_factory_detail::build_single_cta_launcher make_cagra_single_cta_jit_launcher( IndexT, DistanceT, SourceIndexT>( - dataset_desc, topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps, persistent); + dataset_desc, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + persistent, + std::move(sample_filter_udf_fragment)); } /// Build a JIT AlgorithmLauncher for multi-CTA CAGRA search. @@ -266,7 +284,8 @@ template std::shared_ptr make_cagra_multi_cta_jit_launcher( - const dataset_descriptor_host& dataset_desc) + const dataset_descriptor_host& dataset_desc, + std::unique_ptr sample_filter_udf_fragment = nullptr) { using DataTag = decltype(get_data_type_tag()); using IndexTag = decltype(get_index_type_tag()); @@ -286,7 +305,8 @@ std::shared_ptr make_cagra_multi_cta_jit_launcher( DataT, IndexT, DistanceT, - SourceIndexT>(dataset_desc); + SourceIndexT>( + dataset_desc, std::move(sample_filter_udf_fragment)); } using CodebookTag = codebook_tag_standard_t; if (dataset_desc.metric == cuvs::distance::DistanceType::BitwiseHamming) { @@ -302,7 +322,8 @@ std::shared_ptr make_cagra_multi_cta_jit_launcher( DataT, IndexT, DistanceT, - SourceIndexT>(dataset_desc); + SourceIndexT>( + dataset_desc, std::move(sample_filter_udf_fragment)); } using QueryTag = query_type_tag_standard_t; return cagra_jit_launcher_factory_detail::build_multi_cta_launcher make_cagra_multi_cta_jit_launcher( DataT, IndexT, DistanceT, - 
SourceIndexT>(dataset_desc); + SourceIndexT>( + dataset_desc, std::move(sample_filter_udf_fragment)); } /// Build a JIT AlgorithmLauncher for multi-kernel CAGRA helpers that need `setup_workspace` and @@ -331,7 +353,8 @@ template std::shared_ptr make_cagra_multi_kernel_jit_launcher( const dataset_descriptor_host& dataset_desc, - const char* linked_kernel_name) + const char* linked_kernel_name, + std::unique_ptr sample_filter_udf_fragment = nullptr) { using DataTag = decltype(get_data_type_tag()); using IndexTag = decltype(get_index_type_tag()); @@ -352,7 +375,7 @@ std::shared_ptr make_cagra_multi_kernel_jit_launcher( IndexT, DistanceT, SourceIndexT>( - dataset_desc, linked_kernel_name); + dataset_desc, linked_kernel_name, std::move(sample_filter_udf_fragment)); } using CodebookTag = codebook_tag_standard_t; if (dataset_desc.metric == cuvs::distance::DistanceType::BitwiseHamming) { @@ -369,7 +392,7 @@ std::shared_ptr make_cagra_multi_kernel_jit_launcher( IndexT, DistanceT, SourceIndexT>( - dataset_desc, linked_kernel_name); + dataset_desc, linked_kernel_name, std::move(sample_filter_udf_fragment)); } using QueryTag = query_type_tag_standard_t; return cagra_jit_launcher_factory_detail::build_multi_kernel_launcher make_cagra_multi_kernel_jit_launcher( IndexT, DistanceT, SourceIndexT>( - dataset_desc, linked_kernel_name); + dataset_desc, linked_kernel_name, std::move(sample_filter_udf_fragment)); } /// JIT launcher for the post-search `apply_filter_kernel` only (no workspace / distance fragments). 
@@ -396,7 +419,8 @@ template std::shared_ptr make_cagra_apply_filter_jit_launcher( - const dataset_descriptor_host& dataset_desc) + const dataset_descriptor_host& dataset_desc, + std::unique_ptr sample_filter_udf_fragment = nullptr) { using DataTag = decltype(get_data_type_tag()); using IndexTag = decltype(get_index_type_tag()); @@ -416,7 +440,8 @@ std::shared_ptr make_cagra_apply_filter_jit_launcher( DataT, IndexT, DistanceT, - SourceIndexT>(); + SourceIndexT>( + std::move(sample_filter_udf_fragment)); } using CodebookTag = codebook_tag_standard_t; if (dataset_desc.metric == cuvs::distance::DistanceType::BitwiseHamming) { @@ -432,7 +457,8 @@ std::shared_ptr make_cagra_apply_filter_jit_launcher( DataT, IndexT, DistanceT, - SourceIndexT>(); + SourceIndexT>( + std::move(sample_filter_udf_fragment)); } using QueryTag = query_type_tag_standard_t; return cagra_jit_launcher_factory_detail::build_apply_filter_only_launcher make_cagra_apply_filter_jit_launcher( DataT, IndexT, DistanceT, - SourceIndexT>(); + SourceIndexT>( + std::move(sample_filter_udf_fragment)); } } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp index 317ca1a1b6..0c3ed64d13 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp @@ -9,9 +9,11 @@ #include #include #include +#include #include #include +#include #include #include @@ -250,9 +252,16 @@ struct CagraPlannerBase : AlgorithmPlanner { static_cast(dataset_block_dim)); } - void add_sample_filter_device_function() + void add_sample_filter_device_function(std::unique_ptr udf_fragment = nullptr) { - if constexpr (!std::is_same_v) { + if constexpr (std::is_same_v) { + RAFT_EXPECTS(udf_fragment == nullptr, "Unexpected CAGRA sample-filter UDF fragment"); + } else if constexpr (std::is_same_v) { + 
RAFT_EXPECTS(udf_fragment != nullptr, "CAGRA UDF filter requires a JIT-LTO fragment"); + this->add_fragment(std::move(udf_fragment)); + } else { + RAFT_EXPECTS(udf_fragment == nullptr, "Built-in CAGRA sample filters use static fragments"); this->add_static_fragment>(); diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in index 0035ae1da0..d0cc88bc74 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in @@ -14,7 +14,7 @@ using distance_t = @distance_type@; using source_index_t = @source_index_type@; using dataset_desc_base = cuvs::neighbors::cagra::detail::dataset_descriptor_base_t; -using cagra_bitset_t = cuvs::neighbors::cagra::detail::cagra_bitset; +using cagra_sample_filter_t = cuvs::neighbors::cagra::detail::cagra_sample_filter; } // namespace @@ -35,7 +35,7 @@ extern "C" __global__ void compute_distance_to_child_nodes(const index_t* const index_t* const result_indices_ptr, distance_t* const result_distances_ptr, const std::uint32_t ldd, - cagra_bitset_t bitset) + cagra_sample_filter_t filter_payload) { compute_distance_to_child_nodes_kernel_jit( parent_node_list, @@ -53,7 +53,7 @@ extern "C" __global__ void compute_distance_to_child_nodes(const index_t* const result_indices_ptr, result_distances_ptr, ldd, - bitset); + filter_payload); } static_assert( diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp index 370dbd33d8..c9702d3b72 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp @@ -12,7 +12,7 @@ #include #include "../compute_distance.hpp" // dataset_descriptor_base_t -#include 
"cagra_bitset.cuh" +#include "cagra_filter_payload.cuh" #include "search_single_cta_device_helpers.cuh" namespace cuvs::neighbors::cagra::detail { @@ -47,7 +47,7 @@ using search_single_cta_kernel_func_t = const std::uint32_t, const dataset_descriptor_base_t*, const IndexT, - cagra_bitset); + cagra_sample_filter); namespace single_cta_search { @@ -76,7 +76,7 @@ using search_single_cta_p_kernel_func_t = const std::uint32_t, const std::uint32_t, const dataset_descriptor_base_t*, - cagra_bitset); + cagra_sample_filter); } // namespace single_cta_search @@ -105,7 +105,7 @@ using search_multi_cta_kernel_func_t = std::uint32_t* const, const IndexT, const std::uint32_t, - cagra_bitset); + cagra_sample_filter); } // namespace multi_cta_search @@ -143,7 +143,7 @@ using compute_distance_to_child_nodes_kernel_func_t = IndexT* const, DistanceT* const, const std::uint32_t, - cagra_bitset); + cagra_sample_filter); template using apply_filter_kernel_func_t = void(const SourceIndexT* const, @@ -153,7 +153,7 @@ using apply_filter_kernel_func_t = void(const SourceIndexT* const, const std::uint32_t, const std::uint32_t, const std::uint32_t, - cagra_bitset); + cagra_sample_filter); } // namespace multi_kernel_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh new file mode 100644 index 0000000000..3e16cf2bcc --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh @@ -0,0 +1,89 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../cagra_filter_payload.hpp" + +#include +#include + +#include +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail { + +template +inline constexpr const char* cagra_udf_source_index_type_name() +{ + static_assert(std::is_same_v, + "CAGRA filter UDFs currently support SourceIndexT = uint32_t only"); + return "uint32_t"; +} + +inline std::string instantiate_cagra_sample_filter_udf(std::string const& user_source, + std::string const& function_name, + const char* source_index_type) +{ + std::ostringstream oss; + oss << R"( +using int8_t = signed char; +using uint8_t = unsigned char; +using int32_t = int; +using uint32_t = unsigned int; +using int64_t = long long; +using uint64_t = unsigned long long; +using source_index_t = )" + << source_index_type << R"(; + +namespace cuvs::neighbors::detail { +template +extern __device__ bool sample_filter(uint32_t query_id, SourceIndexT node_id, void* filter_data); +} // namespace cuvs::neighbors::detail + +)"; + oss << user_source << R"( + +namespace cuvs::neighbors::detail { + +template <> +__device__ bool sample_filter(uint32_t query_id, + source_index_t node_id, + void* filter_data) +{ + return )" + << function_name << R"((query_id, node_id, filter_data); +} + +} // namespace cuvs::neighbors::detail +)"; + return oss.str(); +} + +template +std::unique_ptr make_cagra_sample_filter_udf_fragment( + const SampleFilterT& sample_filter) +{ + const auto* udf = get_cagra_udf_filter(sample_filter); + if (udf == nullptr) { return nullptr; } + + RAFT_EXPECTS(!udf->source.empty(), "CAGRA filter UDF source must not be empty"); + RAFT_EXPECTS(!udf->function_name.empty(), "CAGRA filter UDF function name must not be empty"); + + auto code = instantiate_cagra_sample_filter_udf( + udf->source, udf->function_name, cagra_udf_source_index_type_name()); + std::string key = "cagra_sample_filter_udf:"; + key += cagra_udf_source_index_type_name(); 
+ key += ":"; + key += udf->function_name; + key += ":"; + key += code; + return nvrtc_compiler().compile(key, code); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh index 492195f359..4c4f2e4f62 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh @@ -20,7 +20,7 @@ #include #endif -#include "cagra_bitset.cuh" +#include "cagra_filter_payload.cuh" #include "device_common_jit.cuh" #include "extern_device_functions.cuh" @@ -52,7 +52,7 @@ __device__ void search_kernel_jit( uint32_t* const num_executed_iterations, /* stats */ const IndexT graph_size, const uint32_t query_id_offset, // Offset to add to query_id when calling filter - cagra_bitset bitset) + cagra_sample_filter filter_payload) { using DATA_T = DataT; using INDEX_T = IndexT; @@ -261,7 +261,7 @@ __device__ void search_kernel_jit( const auto parent_id = result_indices_buffer[parent_indices_buffer[p]] & ~index_msb_1_mask; if (!sample_filter(query_id + query_id_offset, to_source_index(parent_id), - bitset.bitset_ptr != nullptr ? &bitset : nullptr)) { + filter_payload.sample_filter_data())) { // If the parent must not be in the resulting top-k list, remove from the parent list result_distances_buffer[parent_indices_buffer[p]] = utils::get_max_value(); result_indices_buffer[parent_indices_buffer[p]] = invalid_index; @@ -280,7 +280,7 @@ __device__ void search_kernel_jit( index &= ~index_msb_1_mask; if (!sample_filter(query_id + query_id_offset, to_source_index(index), - bitset.bitset_ptr != nullptr ? 
&bitset : nullptr)) { + filter_payload.sample_filter_data())) { result_indices_buffer[i] = invalid_index; result_distances_buffer[i] = utils::get_max_value(); } diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in index 9acd73687b..1d22bd5449 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in @@ -14,7 +14,7 @@ using distance_t = @distance_type@; using source_index_t = @source_index_type@; using dataset_desc_base = cuvs::neighbors::cagra::detail::dataset_descriptor_base_t; -using cagra_bitset_t = cuvs::neighbors::cagra::detail::cagra_bitset; +using cagra_sample_filter_t = cuvs::neighbors::cagra::detail::cagra_sample_filter; } // namespace @@ -42,7 +42,7 @@ extern "C" __global__ __launch_bounds__(1024, 1) void search_multi_cta( std::uint32_t* const num_executed_iterations, const index_t graph_size, const std::uint32_t query_id_offset, - cagra_bitset_t bitset) + cagra_sample_filter_t filter_payload) { search_kernel_jit(result_indices_ptr, result_distances_ptr, @@ -65,7 +65,7 @@ extern "C" __global__ __launch_bounds__(1024, 1) void search_multi_cta( num_executed_iterations, graph_size, query_id_offset, - bitset); + filter_payload); } static_assert( diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh index a7ab078343..a714ced5c2 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh @@ -8,7 +8,7 @@ #include "../../neighbors_device_intrinsics.cuh" #include "../hashmap.hpp" #include "../utils.hpp" -#include "cagra_bitset.cuh" +#include "cagra_filter_payload.cuh" #include #include @@ -105,7 +105,7 @@ __device__ void 
compute_distance_to_child_nodes_kernel_jit( IndexT* const result_indices_ptr, // [num_queries, ldd] DistanceT* const result_distances_ptr, // [num_queries, ldd] const std::uint32_t ldd, // (*) ldd >= search_width * graph_degree - cagra_bitset bitset) + cagra_sample_filter filter_payload) { using INDEX_T = IndexT; using DISTANCE_T = DistanceT; @@ -113,15 +113,16 @@ __device__ void compute_distance_to_child_nodes_kernel_jit( // Get team_size_bits directly from base descriptor uint32_t team_size_bits = dataset_desc->team_size_bitshift(); - const auto team_size = 1u << team_size_bits; - const uint32_t ldb = hashmap::get_size(hash_bitlen); - const auto tid = threadIdx.x + blockDim.x * blockIdx.x; - const auto global_team_id = tid >> team_size_bits; - const auto query_id = blockIdx.y; + const auto team_size = 1u << team_size_bits; + const uint32_t ldb = hashmap::get_size(hash_bitlen); + const auto tid = threadIdx.x + blockDim.x * blockIdx.x; + const auto global_team_id = tid >> team_size_bits; + const auto local_query_id = blockIdx.y; + const auto filter_query_id = filter_payload.query_id_offset + local_query_id; extern __shared__ uint8_t smem[]; auto smem_desc = - setup_workspace(dataset_desc, smem, query_ptr, query_id); + setup_workspace(dataset_desc, smem, query_ptr, local_query_id); __syncthreads(); if (global_team_id >= search_width * graph_degree) { return; } @@ -132,10 +133,11 @@ __device__ void compute_distance_to_child_nodes_kernel_jit( if (parent_list_index == utils::get_max_value()) { return; } constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const auto raw_parent_index = parent_candidates_ptr[parent_list_index + (lds * query_id)]; + const auto raw_parent_index = parent_candidates_ptr[parent_list_index + (lds * local_query_id)]; if (raw_parent_index == utils::get_max_value()) { - result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); + result_distances_ptr[ldd * local_query_id + global_team_id] = + 
utils::get_max_value(); return; } const auto parent_index = raw_parent_index & ~index_msb_1_mask; @@ -153,24 +155,24 @@ __device__ void compute_distance_to_child_nodes_kernel_jit( if (compute_distance_flag) { if ((threadIdx.x & (team_size - 1)) == 0) { - result_indices_ptr[ldd * blockIdx.y + global_team_id] = child_id; - result_distances_ptr[ldd * blockIdx.y + global_team_id] = norm2; + result_indices_ptr[ldd * local_query_id + global_team_id] = child_id; + result_distances_ptr[ldd * local_query_id + global_team_id] = norm2; } } else { if ((threadIdx.x & (team_size - 1)) == 0) { - result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); + result_distances_ptr[ldd * local_query_id + global_team_id] = + utils::get_max_value(); } } - if (bitset.bitset_ptr != nullptr) { - const SourceIndexT node_id = source_indices_ptr == nullptr - ? static_cast(parent_index) - : static_cast(source_indices_ptr[parent_index]); - if (!sample_filter(query_id, node_id, &bitset)) { - parent_candidates_ptr[parent_list_index + (lds * query_id)] = utils::get_max_value(); - parent_distance_ptr[parent_list_index + (lds * query_id)] = - utils::get_max_value(); - } + const SourceIndexT node_id = source_indices_ptr == nullptr + ? 
static_cast(parent_index) + : static_cast(source_indices_ptr[parent_index]); + if (!sample_filter(filter_query_id, node_id, filter_payload.sample_filter_data())) { + parent_candidates_ptr[parent_list_index + (lds * local_query_id)] = + utils::get_max_value(); + parent_distance_ptr[parent_list_index + (lds * local_query_id)] = + utils::get_max_value(); } } @@ -185,7 +187,7 @@ __device__ void apply_filter_kernel_jit( const std::uint32_t result_buffer_size, const std::uint32_t num_queries, const std::uint32_t query_id_offset, - cagra_bitset bitset) + cagra_sample_filter filter_payload) { constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value; const auto tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -195,14 +197,14 @@ __device__ void apply_filter_kernel_jit( const auto index = i + j * lds; if (result_indices_ptr[index] != ~index_msb_1_mask) { - // Use extern sample_filter function with 3 params: query_id, node_id, filter_data - // Third argument is &bitset (layout matches bitset_filter_data_t) or nullptr for none_filter + // Use extern sample_filter function with 3 params: query_id, node_id, filter_data. + // The payload maps built-in bitset filters and UDF context pointers to the linked ABI. SourceIndexT node_id = source_indices_ptr == nullptr ? static_cast(result_indices_ptr[index]) : source_indices_ptr[result_indices_ptr[index]]; if (!sample_filter( - query_id_offset + j, node_id, bitset.bitset_ptr != nullptr ? 
&bitset : nullptr)) { + query_id_offset + j, node_id, filter_payload.sample_filter_data())) { result_indices_ptr[index] = utils::get_max_value(); result_distances_ptr[index] = utils::get_max_value(); } diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh index 282490559a..d481c44946 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh @@ -29,7 +29,7 @@ #endif // Include extern function declarations before namespace so they're available to kernel definitions -#include "cagra_bitset.cuh" +#include "cagra_filter_payload.cuh" #include "extern_device_functions.cuh" // Include shared JIT device functions #include "device_common_jit.cuh" @@ -78,7 +78,7 @@ RAFT_DEVICE_INLINE_FUNCTION void search_core( const std::uint32_t query_id, const std::uint32_t query_id_offset, // Offset to add to query_id when calling filter const dataset_descriptor_base_t* dataset_desc, - cagra_bitset bitset, + cagra_sample_filter filter_payload, const IndexT graph_size = 0) // Original number of bits { using LOAD_T = device::LOAD_128BIT_T; @@ -306,7 +306,7 @@ RAFT_DEVICE_INLINE_FUNCTION void search_core( const auto parent_id = result_indices_buffer[parent_list_buffer[p]] & ~index_msb_1_mask; if (!sample_filter(query_id + query_id_offset, to_source_index(parent_id), - bitset.bitset_ptr != nullptr ? &bitset : nullptr)) { + filter_payload.sample_filter_data())) { result_distances_buffer[parent_list_buffer[p]] = utils::get_max_value(); result_indices_buffer[parent_list_buffer[p]] = invalid_index; *filter_flag = 1; @@ -327,7 +327,7 @@ RAFT_DEVICE_INLINE_FUNCTION void search_core( if (node_id != (invalid_index & ~index_msb_1_mask) && !sample_filter(query_id + query_id_offset, to_source_index(node_id), - bitset.bitset_ptr != nullptr ? 
&bitset : nullptr)) { + filter_payload.sample_filter_data())) { result_distances_buffer[i] = utils::get_max_value(); result_indices_buffer[i] = invalid_index; } @@ -483,7 +483,7 @@ __device__ void search_kernel_jit( const std::uint32_t query_id_offset, // Offset to add to query_id when calling filter const dataset_descriptor_base_t* dataset_desc, const IndexT graph_size, - cagra_bitset bitset) + cagra_sample_filter filter_payload) { const auto query_id = blockIdx.y; search_core* dataset_desc, - cagra_bitset bitset) + cagra_sample_filter filter_payload) { using job_desc_type = job_desc_t>; __shared__ typename job_desc_type::input_t job_descriptor; @@ -625,7 +625,7 @@ __device__ void search_single_cta_p_impl( query_id, query_id_offset, dataset_desc, - bitset); + filter_payload); // make sure all writes are visible even for the host // (e.g. when result buffers are in pinned memory) diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in index 26b363a90a..40e7b15a6f 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in @@ -17,7 +17,7 @@ using distance_t = @distance_type@; using source_index_t = @source_index_type@; using dataset_desc_base = cuvs::neighbors::cagra::detail::dataset_descriptor_base_t; -using cagra_bitset_t = cuvs::neighbors::cagra::detail::cagra_bitset; +using cagra_sample_filter_t = cuvs::neighbors::cagra::detail::cagra_sample_filter; } // namespace @@ -49,7 +49,7 @@ extern "C" __global__ __launch_bounds__(1024, 1) void search_single_cta( const std::uint32_t query_id_offset, const dataset_desc_base* dataset_desc, const index_t graph_size, - cagra_bitset_t bitset) + cagra_sample_filter_t filter_payload) { single_cta_search::search_kernel_jit; using job_descriptor_batch = scta_jit::job_desc_t>; -using cagra_bitset_t = 
cuvs::neighbors::cagra::detail::cagra_bitset; +using cagra_sample_filter_t = cuvs::neighbors::cagra::detail::cagra_sample_filter; } // namespace @@ -51,7 +51,7 @@ extern "C" __global__ __launch_bounds__(1024, 1) void search_single_cta_p( const std::uint32_t small_hash_reset_interval, const std::uint32_t query_id_offset, const dataset_desc_base* dataset_desc, - cagra_bitset_t bitset) + cagra_sample_filter_t filter_payload) { search_single_cta_p_impl>; +using udf_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset< + cuvs::neighbors::filtering::udf_filter>; } // namespace @@ -20,5 +22,6 @@ instantiate_kernel_selection(data_t, float, cuvs::neighbors::filtering::none_sample_filter); instantiate_kernel_selection(data_t, uint32_t, float, bitset_filter_t); +instantiate_kernel_selection(data_t, uint32_t, float, udf_filter_t); } // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh index 663f8a4559..8a673405b7 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh @@ -56,15 +56,16 @@ void select_and_run(const dataset_descriptor_host& dat SampleFilterT sample_filter, cudaStream_t stream) { - const auto bf = extract_cagra_sample_filter(sample_filter); - const uint32_t query_id_offset = bf.query_id_offset; + const auto filter_payload = extract_cagra_sample_filter(sample_filter, stream); + const uint32_t query_id_offset = filter_payload.query_id_offset; std::shared_ptr launcher = make_cagra_multi_cta_jit_launcher>(dataset_desc); + sample_filter_jit_tag_t>( + dataset_desc, make_cagra_sample_filter_udf_fragment(sample_filter)); if (!launcher) { RAFT_FAIL("Failed to get JIT launcher"); } @@ -142,7 +143,7 @@ void select_and_run(const dataset_descriptor_host& dat num_executed_iterations, 
static_cast(graph.extent(0)), query_id_offset, - bf.bitset); + filter_payload); }; cuvs::neighbors::detail::safely_launch_kernel_with_smem_size< multi_cta_search::search_multi_cta_kernel_func_t>( diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh index f6dfcb8aee..e09ef82a39 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh @@ -519,7 +519,9 @@ struct search DISTANCE_T, SourceIndexT, sample_filter_jit_tag_t>( - dataset_desc, "compute_distance_to_child_nodes"); + dataset_desc, + "compute_distance_to_child_nodes", + make_cagra_sample_filter_udf_fragment(sample_filter)); unsigned iter = 0; while (1) { @@ -615,7 +617,7 @@ struct search DISTANCE_T, SourceIndexT, sample_filter_jit_tag_t>( - dataset_desc); + dataset_desc, make_cagra_sample_filter_udf_fragment(sample_filter)); apply_filter_jit( source_indices_ptr, diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh index d0e3db2c99..bc341b9082 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh @@ -13,7 +13,7 @@ #include "jit_lto_kernels/kernel_def.hpp" #include "jit_lto_kernels/search_multi_kernel_planner.hpp" #include "search_plan.cuh" // For search_params -#include "shared_launcher_jit.hpp" // cagra_bitset / cagra_sample_filter, sample_filter_jit_tag_t, tags +#include "shared_launcher_jit.hpp" // sample-filter payload helpers and JIT tags #include #include #include @@ -111,7 +111,7 @@ void compute_distance_to_child_nodes_jit( cudaStream_t cuda_stream, std::shared_ptr const& launcher) { - const auto bf = extract_cagra_sample_filter(sample_filter); + const auto filter_payload = extract_cagra_sample_filter(sample_filter, cuda_stream); const auto block_size = 128; const auto 
teams_per_block = block_size / dataset_desc.team_size; @@ -142,7 +142,7 @@ void compute_distance_to_child_nodes_jit( result_indices_ptr, result_distances_ptr, ldd, - bf.bitset); + filter_payload); RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -160,9 +160,8 @@ void apply_filter_jit(const SourceIndexT* source_indices_ptr, cudaStream_t cuda_stream, std::shared_ptr const& launcher) { - // Note: query_id for the linked filter is the function's `query_id_offset` + query index, not - // the wrapper's offset; we only need bitset pointers (same as other JIT launchers). - const auto bf = extract_cagra_sample_filter(sample_filter); + const auto filter_payload = extract_cagra_sample_filter(sample_filter, cuda_stream); + const auto effective_query_id_offset = query_id_offset + filter_payload.query_id_offset; const std::uint32_t block_size = 256; const std::uint32_t grid_size = raft::ceildiv(num_queries * result_buffer_size, block_size); @@ -181,8 +180,8 @@ void apply_filter_jit(const SourceIndexT* source_indices_ptr, lds, result_buffer_size, num_queries, - query_id_offset, - bf.bitset); + effective_query_id_offset, + filter_payload); RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in index 85342e7093..4616a9652b 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in @@ -11,6 +11,8 @@ namespace { using data_t = @data_type@; using bitset_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset< cuvs::neighbors::filtering::bitset_filter>; +using udf_filter_t = cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset< + cuvs::neighbors::filtering::udf_filter>; } // namespace @@ -20,5 +22,6 @@ instantiate_kernel_selection(data_t, float, cuvs::neighbors::filtering::none_sample_filter); instantiate_kernel_selection(data_t, uint32_t, float, 
bitset_filter_t); +instantiate_kernel_selection(data_t, uint32_t, float, udf_filter_t); } // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh index 96db9a743b..aa4a78c513 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh @@ -45,6 +45,7 @@ #include #include #include +#include #include #include #include @@ -57,6 +58,54 @@ namespace cuvs::neighbors::cagra::detail::single_cta_search { +inline std::uint64_t cagra_hash_combine(std::uint64_t seed, std::uint64_t value) +{ + return seed ^ (value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2)); +} + +template +std::uint64_t cagra_udf_source_hash(const SampleFilterT& sample_filter) +{ + if (const auto* udf = get_cagra_udf_filter(sample_filter); udf != nullptr) { + std::uint64_t seed = 0; + seed = cagra_hash_combine(seed, std::hash{}(udf->source)); + seed = cagra_hash_combine(seed, std::hash{}(udf->function_name)); + return seed; + } + return 0; +} + +template +std::uint64_t cagra_sample_filter_type_id(const SampleFilterT& sample_filter) +{ + using DecayedFilter = std::decay_t; + if constexpr (is_udf_filter::value) { + return 2; + } else if constexpr (is_bitset_filter::value) { + return 1; + } else if constexpr (requires { sample_filter.filter; }) { + return cagra_sample_filter_type_id(sample_filter.filter); + } else { + return 0; + } +} + +template +std::uint64_t cagra_sample_filter_hash(const SampleFilterT& sample_filter) +{ + std::uint64_t seed = cagra_sample_filter_type_id(sample_filter); + seed = cagra_hash_combine( + seed, cagra_filter_payload_hash(sample_filter)); + seed = cagra_hash_combine( + seed, + static_cast(reinterpret_cast( + cagra_filter_data_ptr(sample_filter)))); + seed = cagra_hash_combine( + seed, 
static_cast(cagra_filter_query_id_offset(sample_filter))); + seed = cagra_hash_combine(seed, cagra_udf_source_hash(sample_filter)); + return seed; +} + // Persistent queues / runner (host). worker_handle_t, job_desc_t, kCacheLineBytes, k* job limits: // `jit_lto_kernels/search_single_cta_device_helpers.cuh` via `kernel_def.hpp`. @@ -453,7 +502,7 @@ struct alignas(kCacheLineBytes) persistent_runner_jit_t : public persistent_runn rmm::device_uvector hashmap; std::atomic> last_touch; uint64_t param_hash; - cagra_bitset bitset; + cagra_sample_filter filter_payload; static inline auto calculate_parameter_hash( std::reference_wrapper> dataset_desc, @@ -481,7 +530,7 @@ struct alignas(kCacheLineBytes) persistent_runner_jit_t : public persistent_runn bool bitonic_sort_and_merge_multi_warps) -> uint64_t { (void)small_hash_bitlen; - (void)sample_filter; + const uint64_t filter_key = cagra_sample_filter_hash(sample_filter); const uint64_t bitonic_key = (topk_by_bitonic_sort ? 1ULL : 0ULL) ^ (bitonic_sort_and_merge_multi_warps ? 
2ULL : 0ULL); return uint64_t(graph.data_handle()) ^ uint64_t(source_indices_ptr) ^ @@ -489,13 +538,14 @@ struct alignas(kCacheLineBytes) persistent_runner_jit_t : public persistent_runn hash_bitlen ^ small_hash_reset_interval ^ num_random_samplings ^ rand_xor_mask ^ num_seeds ^ itopk_size ^ search_width ^ min_iterations ^ max_iterations ^ uint64_t(persistent_lifetime * 1000) ^ uint64_t(persistent_device_usage * 1000) ^ - bitonic_key; + bitonic_key ^ filter_key; } static auto make_persistent_launcher( const dataset_descriptor_host& dataset_desc, bool topk_by_bitonic_sort, - bool bitonic_sort_and_merge_multi_warps) -> std::shared_ptr + bool bitonic_sort_and_merge_multi_warps, + SampleFilterT sample_filter) -> std::shared_ptr { auto launcher = make_cagra_single_cta_jit_launcher(sample_filter)); if (!launcher) { RAFT_FAIL("Failed to get JIT launcher for CAGRA persistent search kernel"); } return launcher; } @@ -535,8 +586,10 @@ struct alignas(kCacheLineBytes) persistent_runner_jit_t : public persistent_runn bool topk_by_bitonic_sort, bool bitonic_sort_and_merge_multi_warps) : persistent_runner_base_t{persistent_lifetime}, - launcher{make_persistent_launcher( - dataset_desc.get(), topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps)}, + launcher{make_persistent_launcher(dataset_desc.get(), + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + sample_filter)}, block_size{block_size}, worker_handles(0, stream, worker_handles_mr), job_descriptors(kMaxJobsNum, stream, job_descriptor_mr), @@ -567,9 +620,8 @@ struct alignas(kCacheLineBytes) persistent_runner_jit_t : public persistent_runn topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps)) { - const auto bf = extract_cagra_sample_filter(sample_filter); - this->bitset = bf.bitset; - const uint32_t query_id_offset = bf.query_id_offset; + this->filter_payload = extract_cagra_sample_filter(sample_filter, stream); + const uint32_t query_id_offset = filter_payload.query_id_offset; // set kernel launch 
parameters dim3 gs = calc_coop_grid_size(block_size, smem_size, persistent_device_usage); @@ -656,7 +708,7 @@ struct alignas(kCacheLineBytes) persistent_runner_jit_t : public persistent_runn small_hash_reset_interval_u32, // Cast size_t to uint32_t query_id_offset, // Offset to add to query_id when calling filter dev_desc, - bitset); + filter_payload); last_touch.store(std::chrono::system_clock::now(), std::memory_order_relaxed); } @@ -757,9 +809,8 @@ void select_and_run( const SourceIndexT* source_indices_ptr = source_indices.has_value() ? source_indices->data_handle() : nullptr; - const auto bf = extract_cagra_sample_filter(sample_filter); - const cagra_bitset bitset = bf.bitset; - const uint32_t query_id_offset = bf.query_id_offset; + const auto filter_payload = extract_cagra_sample_filter(sample_filter, stream); + const uint32_t query_id_offset = filter_payload.query_id_offset; // Use common logic to compute launch config auto config = compute_launch_config(num_itopk_candidates, ps.itopk_size, block_size); @@ -813,7 +864,8 @@ void select_and_run( dataset_desc, topk_by_bitonic_sort, bitonic_sort_and_merge_multi_warps, - false /* persistent */); + false /* persistent */, + make_cagra_sample_filter_udf_fragment(sample_filter)); if (!launcher) { RAFT_FAIL("Failed to get JIT launcher for CAGRA search kernel"); } // Get the device descriptor pointer - dev_ptr() initializes it if needed @@ -871,7 +923,7 @@ void select_and_run( query_id_offset, // Offset to add to query_id when calling filter dev_desc, static_cast(graph.extent(0)), - bitset); + filter_payload); }; cuvs::neighbors::detail::safely_launch_kernel_with_smem_size< diff --git a/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp index a16c810409..e5157ffa6a 100644 --- a/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp +++ b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp @@ -8,8 +8,8 @@ #include #include -#include 
"../../sample_filter.cuh" // For none_sample_filter, bitset_filter -#include "jit_lto_kernels/cagra_bitset.cuh" // is_bitset_filter, cagra_bitset, cagra_sample_filter, extract +#include "../../sample_filter.cuh" // For none_sample_filter, bitset_filter +#include "cagra_filter_payload.hpp" // sample-filter payload helpers #include #include @@ -100,19 +100,24 @@ struct sample_filter_jit_tag { using namespace cuvs::neighbors::filtering; if constexpr (std::is_same_v) { return cuvs::neighbors::detail::tag_filter_none{}; + } else if constexpr (is_udf_filter::value) { + return cuvs::neighbors::detail::tag_filter_udf{}; } else if constexpr (requires { std::declval().filter; }) { using InnerFilter = decltype(std::declval().filter); - if constexpr (is_bitset_filter::value || - std::is_same_v> || - std::is_same_v>) { + if constexpr (is_bitset_filter>::value || + std::is_same_v, bitset_filter> || + std::is_same_v, + bitset_filter>) { return cuvs::neighbors::detail::tag_filter_bitset{}; + } else if constexpr (is_udf_filter>::value) { + return cuvs::neighbors::detail::tag_filter_udf{}; } else { static_assert( cagra_jit_sample_filter_tag_type_always_false, "CAGRA JIT: sample_filter_jit_tag does not know how to link this filter. " - "CagraSampleFilterWithQueryIdOffset requires Inner of type " - "bitset_filter (see cagra_bitset.cuh is_bitset_filter and sample_filter_utils.cuh). " + "CagraSampleFilterWithQueryIdOffset requires Inner to be a supported " + "built-in filter or udf_filter (see cagra_filter_payload.hpp and " + "sample_filter_utils.cuh). " "For a new filter kind, add a sample_filter_jit_tag branch. 
" "(SAMPLE_FILTER_T in error; check InnerFilter in compiler output.)"); } @@ -120,9 +125,9 @@ struct sample_filter_jit_tag { static_assert( cagra_jit_sample_filter_tag_type_always_false, "CAGRA JIT: sample_filter_jit_tag: SAMPLE_FILTER_T must be cuvs::neighbors::filtering::" - "none_sample_filter, or " - "cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset<" - "bitset_filter>. Unknown wrapper type. " + "none_sample_filter, udf_filter, or " + "cuvs::neighbors::cagra::detail::CagraSampleFilterWithQueryIdOffset. " + "Unknown wrapper type. " "(SAMPLE_FILTER_T in error; add a branch in sample_filter_jit_tag.)"); } } diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 1df9c8774b..9b96f94bf0 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -187,6 +187,13 @@ ConfigureTest( PERCENT 100 ) +ConfigureTest( + NAME NEIGHBORS_ANN_CAGRA_FILTER_UDF_TEST + PATH neighbors/ann_cagra/test_filter_udf.cu + GPUS 1 + PERCENT 100 +) + ConfigureTest( NAME NEIGHBORS_ANN_CAGRA_HELPERS_TEST PATH neighbors/ann_cagra/test_optimize_uint32_t.cu neighbors/ann_cagra/test_batch_load_iterator.cu diff --git a/cpp/tests/neighbors/ann_cagra/test_filter_udf.cu b/cpp/tests/neighbors/ann_cagra/test_filter_udf.cu new file mode 100644 index 0000000000..ba65df5d65 --- /dev/null +++ b/cpp/tests/neighbors/ann_cagra/test_filter_udf.cu @@ -0,0 +1,293 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra { +namespace { + +constexpr int64_t n_rows = 768; +constexpr int64_t n_dim = 16; +constexpr int64_t n_queries = 6; +constexpr int64_t k = 8; +constexpr int64_t threshold = 192; +constexpr int64_t high_filtering_threshold = 704; +constexpr float high_filtering_rate = + static_cast(high_filtering_threshold) / static_cast(n_rows); + +struct tenant_filter_context { + const uint32_t* row_tenants; + const uint32_t* query_tenants; +}; + +struct cagra_search_result { + std::vector neighbors; + std::vector distances; +}; + +std::string accept_all_udf_source() +{ + return R"cpp( + __device__ bool cuvs_filter_udf(uint32_t, source_index_t, void*) { return true; } + )cpp"; +} + +std::string threshold_udf_source() +{ + return R"cpp( + __device__ bool cuvs_filter_udf(uint32_t, source_index_t source_id, void*) + { + return source_id >= 192; + } + )cpp"; +} + +std::string reject_all_udf_source() +{ + return R"cpp( + __device__ bool cuvs_filter_udf(uint32_t, source_index_t, void*) { return false; } + )cpp"; +} + +std::string high_filtering_rate_udf_source() +{ + return R"cpp( + __device__ bool cuvs_filter_udf(uint32_t, source_index_t source_id, void*) + { + return source_id >= 704; + } + )cpp"; +} + +std::string tenant_udf_source() +{ + return R"cpp( + struct tenant_filter_context { + const uint32_t* row_tenants; + const uint32_t* query_tenants; + }; + + __device__ bool cuvs_filter_udf(uint32_t query_id, source_index_t source_id, void* filter_data) + { + auto* ctx = static_cast(filter_data); + return ctx->row_tenants[source_id] == ctx->query_tenants[query_id]; + } + )cpp"; +} + +void expect_same_results(cagra_search_result const& expected, cagra_search_result const& actual) +{ + ASSERT_EQ(expected.neighbors, actual.neighbors); + 
ASSERT_EQ(expected.distances.size(), actual.distances.size()); + for (size_t i = 0; i < expected.distances.size(); ++i) { + EXPECT_FLOAT_EQ(expected.distances[i], actual.distances[i]); + } +} + +class CagraUdfFilterTest : public ::testing::TestWithParam { + protected: + void SetUp() override + { + dataset.emplace(raft::make_device_matrix(res, n_rows, n_dim)); + queries.emplace(raft::make_device_matrix(res, n_queries, n_dim)); + + raft::random::RngState rng(1234ULL); + raft::random::uniform(res, rng, dataset->data_handle(), dataset->size(), -1.0f, 1.0f); + raft::random::uniform(res, rng, queries->data_handle(), queries->size(), -1.0f, 1.0f); + + cagra::index_params index_params; + index_params.metric = cuvs::distance::DistanceType::L2Expanded; + index_params.graph_degree = 32; + index_params.intermediate_graph_degree = 64; + index_params.graph_build_params = + cagra::graph_build_params::nn_descent_params(index_params.intermediate_graph_degree); + + index.emplace(cagra::build(res, index_params, raft::make_const_mdspan(dataset->view()))); + raft::resource::sync_stream(res); + } + + void TearDown() override + { + index.reset(); + queries.reset(); + dataset.reset(); + raft::resource::sync_stream(res); + } + + cagra_search_result search(cuvs::neighbors::filtering::base_filter const& filter, + float filtering_rate = -1.0f) + { + auto neighbors = raft::make_device_matrix(res, n_queries, k); + auto distances = raft::make_device_matrix(res, n_queries, k); + + cagra::search_params search_params; + search_params.algo = GetParam(); + search_params.itopk_size = 64; + search_params.max_queries = 2; + search_params.thread_block_size = 256; + search_params.filtering_rate = filtering_rate; + + cagra::search(res, + search_params, + *index, + raft::make_const_mdspan(queries->view()), + neighbors.view(), + distances.view(), + filter); + + auto stream = raft::resource::get_cuda_stream(res); + cagra_search_result result{std::vector(n_queries * k), + std::vector(n_queries * k)}; + 
raft::copy(result.neighbors.data(), neighbors.data_handle(), result.neighbors.size(), stream); + raft::copy(result.distances.data(), distances.data_handle(), result.distances.size(), stream); + raft::resource::sync_stream(res); + return result; + } + + raft::resources res; + std::optional> dataset = std::nullopt; + std::optional> queries = std::nullopt; + std::optional> index = std::nullopt; +}; + +TEST_P(CagraUdfFilterTest, AcceptAllMatchesNoFilter) +{ + cuvs::neighbors::filtering::none_sample_filter no_filter; + auto expected = search(no_filter, 0.0f); + + cuvs::neighbors::filtering::udf_filter udf_filter(accept_all_udf_source(), nullptr, 0.0f); + auto actual = search(udf_filter); + + expect_same_results(expected, actual); +} + +TEST_P(CagraUdfFilterTest, RejectAllReturnsNoValidNeighbors) +{ + cuvs::neighbors::filtering::udf_filter udf_filter(reject_all_udf_source(), nullptr, 0.999f); + auto result = search(udf_filter); + + // CAGRA algorithms do not all normalize empty-result sentinels the same way. Single-CTA + // clears the internal high-bit marker before writing output, so 0xffffffff can become + // 0x7fffffff; other paths may leave 0xffffffff. Both are invalid row ids for this index. 
+ for (auto source_id : result.neighbors) { + EXPECT_GE(source_id, static_cast(n_rows)); + } +} + +TEST_P(CagraUdfFilterTest, HighFilteringRateReturnsOnlyValidNeighbors) +{ + cuvs::neighbors::filtering::udf_filter udf_filter( + high_filtering_rate_udf_source(), nullptr, high_filtering_rate); + auto result = search(udf_filter); + + for (auto source_id : result.neighbors) { + if (source_id < static_cast(n_rows)) { + EXPECT_GE(source_id, static_cast(high_filtering_threshold)); + } + } +} + +TEST_P(CagraUdfFilterTest, RepeatedUdfSearchWithSameSourceMatches) +{ + cuvs::neighbors::filtering::udf_filter udf_filter(accept_all_udf_source(), nullptr, 0.0f); + + auto first = search(udf_filter); + auto second = search(udf_filter); + + expect_same_results(first, second); +} + +TEST_P(CagraUdfFilterTest, InvalidSourceThrows) +{ + cuvs::neighbors::filtering::udf_filter udf_filter("this is not valid cuda source", nullptr, 0.0f); + + EXPECT_THROW(search(udf_filter), std::exception); +} + +TEST_P(CagraUdfFilterTest, ThresholdMatchesEquivalentBitset) +{ + auto removed_indices = raft::make_device_vector(res, threshold); + thrust::sequence(raft::resource::get_thrust_policy(res), + thrust::device_pointer_cast(removed_indices.data_handle()), + thrust::device_pointer_cast(removed_indices.data_handle() + threshold)); + raft::resource::sync_stream(res); + + cuvs::core::bitset removed_indices_bitset( + res, removed_indices.view(), n_rows); + cuvs::neighbors::filtering::bitset_filter bitset_filter(removed_indices_bitset.view()); + + float const filtering_rate = static_cast(threshold) / static_cast(n_rows); + auto expected = search(bitset_filter, filtering_rate); + + cuvs::neighbors::filtering::udf_filter udf_filter( + threshold_udf_source(), nullptr, filtering_rate); + auto actual = search(udf_filter, filtering_rate); + + expect_same_results(expected, actual); +} + +TEST_P(CagraUdfFilterTest, TenantContextHonorsQuerySpecificMetadata) +{ + std::vector host_row_tenants(n_rows); + std::vector 
host_query_tenants(n_queries); + for (int64_t i = 0; i < n_rows; ++i) { + host_row_tenants[static_cast(i)] = static_cast((i / 5) % 3); + } + for (int64_t q = 0; q < n_queries; ++q) { + host_query_tenants[static_cast(q)] = static_cast(q % 3); + } + + auto row_tenants = raft::make_device_vector(res, n_rows); + auto query_tenants = raft::make_device_vector(res, n_queries); + auto context = raft::make_device_vector(res, 1); + + auto stream = raft::resource::get_cuda_stream(res); + raft::copy(row_tenants.data_handle(), host_row_tenants.data(), host_row_tenants.size(), stream); + raft::copy( + query_tenants.data_handle(), host_query_tenants.data(), host_query_tenants.size(), stream); + + tenant_filter_context host_context{row_tenants.data_handle(), query_tenants.data_handle()}; + raft::copy(context.data_handle(), &host_context, 1, stream); + raft::resource::sync_stream(res); + + cuvs::neighbors::filtering::udf_filter udf_filter( + tenant_udf_source(), context.data_handle(), 2.0f / 3.0f); + auto result = search(udf_filter); + + for (int64_t q = 0; q < n_queries; ++q) { + auto query_tenant = host_query_tenants[static_cast(q)]; + for (int64_t i = 0; i < k; ++i) { + auto source_id = result.neighbors[static_cast(q * k + i)]; + ASSERT_LT(source_id, static_cast(n_rows)); + EXPECT_EQ(host_row_tenants[source_id], query_tenant); + } + } +} + +INSTANTIATE_TEST_CASE_P(CagraUdfFilters, + CagraUdfFilterTest, + ::testing::Values(cagra::search_algo::SINGLE_CTA, + cagra::search_algo::MULTI_CTA, + cagra::search_algo::MULTI_KERNEL)); + +} // namespace +} // namespace cuvs::neighbors::cagra diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt index 034b0b3d96..d63ddbdb71 100644 --- a/examples/cpp/CMakeLists.txt +++ b/examples/cpp/CMakeLists.txt @@ -31,6 +31,7 @@ include(../cmake/thirdparty/get_cuvs.cmake) # -------------- compile tasks ----------------- # add_executable(BRUTE_FORCE_EXAMPLE src/brute_force_bitmap.cu) add_executable(CAGRA_EXAMPLE src/cagra_example.cu) 
+add_executable(CAGRA_FILTER_UDF_EXAMPLE src/cagra_filter_udf_example.cu) add_executable(CAGRA_HNSW_ACE_EXAMPLE src/cagra_hnsw_ace_example.cu) add_executable(CAGRA_PERSISTENT_EXAMPLE src/cagra_persistent_example.cu) add_executable(DYNAMIC_BATCHING_EXAMPLE src/dynamic_batching_example.cu) @@ -44,6 +45,9 @@ add_executable(SCANN_EXAMPLE src/scann_example.cu) # installed in a conda environment, if one exists target_link_libraries(BRUTE_FORCE_EXAMPLE PRIVATE cuvs::cuvs $) target_link_libraries(CAGRA_EXAMPLE PRIVATE cuvs::cuvs $) +target_link_libraries( + CAGRA_FILTER_UDF_EXAMPLE PRIVATE cuvs::cuvs $ +) target_link_libraries(CAGRA_HNSW_ACE_EXAMPLE PRIVATE cuvs::cuvs $) target_link_libraries( CAGRA_PERSISTENT_EXAMPLE PRIVATE cuvs::cuvs $ Threads::Threads diff --git a/examples/cpp/src/cagra_filter_udf_example.cu b/examples/cpp/src/cagra_filter_udf_example.cu new file mode 100644 index 0000000000..0ab42dd580 --- /dev/null +++ b/examples/cpp/src/cagra_filter_udf_example.cu @@ -0,0 +1,251 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace { +/* Sizes used by the example: dataset rows, vector dimension, number of queries, and neighbors requested per query. */ +constexpr int64_t n_rows = 4096; +constexpr int64_t n_dim = 32; +constexpr int64_t n_queries = 4; +constexpr int64_t k = 8; +/* Host-side declaration of the context passed to the UDFs as filter_data. The UDF source string below declares a struct with the same fields in the same order. */ +struct metadata_filter_context { + const uint32_t* row_tenant_ids; + const int64_t* row_timestamps; + const uint32_t* row_language_ids; + const uint64_t* row_acl_masks; + + const uint32_t* query_tenant_ids; + const int64_t* query_min_timestamps; + const uint64_t* query_allowed_language_masks; + const uint64_t* query_permission_masks; +}; +/* Returns the CUDA source handed to udf_filter. It defines three device predicates over the context: tenant_filter, timestamp_filter and language_acl_filter. */ +std::string metadata_udf_source() +{ + return R"cpp( + struct metadata_filter_context { + const uint32_t* row_tenant_ids; + const int64_t* row_timestamps; + const uint32_t* row_language_ids; + const uint64_t* row_acl_masks; + + const uint32_t* query_tenant_ids; + const int64_t* query_min_timestamps; + const uint64_t* query_allowed_language_masks; + const uint64_t* query_permission_masks; + }; + + __device__ bool tenant_filter(uint32_t query_id, source_index_t source_id, void* filter_data) + { + auto* ctx = static_cast(filter_data); + return ctx->row_tenant_ids[source_id] == ctx->query_tenant_ids[query_id]; + } + + __device__ bool timestamp_filter(uint32_t query_id, source_index_t source_id, void* filter_data) + { + auto* ctx = static_cast(filter_data); + return ctx->row_timestamps[source_id] >= ctx->query_min_timestamps[query_id]; + } + + __device__ bool language_acl_filter(uint32_t query_id, source_index_t source_id, void* filter_data) + { + auto* ctx = static_cast(filter_data); + const auto language_bit = uint64_t{1} << ctx->row_language_ids[source_id]; + const bool language_ok = (ctx->query_allowed_language_masks[query_id] & language_bit) != 0; + const bool acl_ok = (ctx->row_acl_masks[source_id] & ctx->query_permission_masks[query_id]) != 0; + return language_ok && acl_ok; + } + )cpp"; +} +/* Copies src.size() elements from src.data() to dst.data_handle() on the CUDA stream of res. No stream synchronization happens here. */ +template +void
copy_to_device(raft::device_resources const& res, DeviceVectorT& dst, HostVectorT const& src) +{ + raft::copy(dst.data_handle(), src.data(), src.size(), raft::resource::get_cuda_stream(res)); +} +/* Copies all neighbor ids into a host vector and synchronizes the stream of res before returning it. */ +std::vector copy_neighbors_to_host( + raft::device_resources const& res, + raft::device_matrix_view neighbors) +{ + std::vector host(neighbors.size()); + raft::copy( + host.data(), neighbors.data_handle(), host.size(), raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); + return host; +} +/* Prints the first k neighbor ids (the first query) and then checks every result with is_valid(query, source_id). Ids at or above n_rows are not checked. On the first failing id it reports to std::cerr and exits with status 1. */ +template +void validate_and_print(char const* name, + raft::device_resources const& res, + raft::device_matrix_view neighbors, + PredicateT is_valid) +{ + auto host_neighbors = copy_neighbors_to_host(res, neighbors); + + std::cout << name << " first query neighbors:"; + for (int64_t i = 0; i < k; ++i) { + std::cout << " " << host_neighbors[static_cast(i)]; + } + std::cout << std::endl; + + for (int64_t q = 0; q < n_queries; ++q) { + for (int64_t i = 0; i < k; ++i) { + auto source_id = host_neighbors[static_cast(q * k + i)]; + if (source_id < static_cast(n_rows) && !is_valid(q, source_id)) { + std::cerr << name << " produced invalid source_id=" << source_id << " for query=" << q + << std::endl; + std::exit(1); + } + } + } +} + +} // namespace +/* Builds a CAGRA index over random vectors, uploads synthetic metadata, and runs one filtered search per UDF entry point, validating each result on the host. */ +int main() +{ + raft::device_resources res; +/* Installs a pool memory resource as the current device resource. NOTE(review): the second argument looks like the initial pool size (1 GiB); confirm against the RMM API. */ + rmm::mr::pool_memory_resource pool_mr(rmm::mr::get_current_device_resource_ref(), + 1024 * 1024 * 1024ull); + rmm::mr::set_current_device_resource(pool_mr); + + auto dataset = raft::make_device_matrix(res, n_rows, n_dim); + auto queries = raft::make_device_matrix(res, n_queries, n_dim); +/* Fills dataset and queries with uniform random values between -1 and 1 from a fixed seed. */ + raft::random::RngState rng(1234ULL); + raft::random::uniform(res, rng, dataset.data_handle(), dataset.size(), -1.0f, 1.0f); + raft::random::uniform(res, rng, queries.data_handle(), queries.size(), -1.0f, 1.0f); + + cuvs::neighbors::cagra::index_params index_params; + index_params.metric = cuvs::distance::DistanceType::L2Expanded; + index_params.graph_degree = 32; +
index_params.intermediate_graph_degree = 64; + index_params.graph_build_params = cuvs::neighbors::cagra::graph_build_params::nn_descent_params( + index_params.intermediate_graph_degree); + + std::cout << "Building CAGRA index" << std::endl; + auto index = + cuvs::neighbors::cagra::build(res, index_params, raft::make_const_mdspan(dataset.view())); +/* Synthetic metadata. Row i: tenant i % 4, timestamp 1700000000 plus i, language i % 8, ACL mask with the single bit i % 16 set. The query-side arrays declared after the loop hold one entry per query. */ + std::vector row_tenant_ids(n_rows); + std::vector row_timestamps(n_rows); + std::vector row_language_ids(n_rows); + std::vector row_acl_masks(n_rows); + for (int64_t i = 0; i < n_rows; ++i) { + row_tenant_ids[static_cast(i)] = static_cast(i % 4); + row_timestamps[static_cast(i)] = 1'700'000'000 + i; + row_language_ids[static_cast(i)] = static_cast(i % 8); + row_acl_masks[static_cast(i)] = uint64_t{1} << (i % 16); + } + + std::vector query_tenant_ids{0, 1, 2, 3}; + std::vector query_min_timestamps{ + 1'700'003'000, 1'700'002'000, 1'700'001'000, 1'700'000'500}; + std::vector query_allowed_language_masks{(uint64_t{1} << 0) | (uint64_t{1} << 1), + (uint64_t{1} << 2) | (uint64_t{1} << 3), + (uint64_t{1} << 4) | (uint64_t{1} << 5), + (uint64_t{1} << 6) | (uint64_t{1} << 7)}; + std::vector query_permission_masks{(uint64_t{1} << 0) | (uint64_t{1} << 8), + (uint64_t{1} << 2) | (uint64_t{1} << 10), + (uint64_t{1} << 4) | (uint64_t{1} << 12), + (uint64_t{1} << 6) | (uint64_t{1} << 14)}; + + auto row_tenant_ids_device = raft::make_device_vector(res, n_rows); + auto row_timestamps_device = raft::make_device_vector(res, n_rows); + auto row_language_ids_device = raft::make_device_vector(res, n_rows); + auto row_acl_masks_device = raft::make_device_vector(res, n_rows); + auto query_tenant_ids_device = raft::make_device_vector(res, n_queries); + auto query_min_timestamps_device = raft::make_device_vector(res, n_queries); + auto query_allowed_language_masks_device = + raft::make_device_vector(res, n_queries); + auto query_permission_masks_device = raft::make_device_vector(res, n_queries); + auto context_device =
raft::make_device_vector(res, 1); +/* Upload every host metadata array to its device buffer. */ + copy_to_device(res, row_tenant_ids_device, row_tenant_ids); + copy_to_device(res, row_timestamps_device, row_timestamps); + copy_to_device(res, row_language_ids_device, row_language_ids); + copy_to_device(res, row_acl_masks_device, row_acl_masks); + copy_to_device(res, query_tenant_ids_device, query_tenant_ids); + copy_to_device(res, query_min_timestamps_device, query_min_timestamps); + copy_to_device(res, query_allowed_language_masks_device, query_allowed_language_masks); + copy_to_device(res, query_permission_masks_device, query_permission_masks); +/* The context struct stores device pointers only. It is filled on the host, copied to context_device, and the stream is synchronized so that all uploads are complete before searching. */ + metadata_filter_context host_context{row_tenant_ids_device.data_handle(), + row_timestamps_device.data_handle(), + row_language_ids_device.data_handle(), + row_acl_masks_device.data_handle(), + query_tenant_ids_device.data_handle(), + query_min_timestamps_device.data_handle(), + query_allowed_language_masks_device.data_handle(), + query_permission_masks_device.data_handle()}; + raft::copy(context_device.data_handle(), &host_context, 1, raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); +/* Output buffers, reused by every filtered search below. */ + auto neighbors = raft::make_device_matrix(res, n_queries, k); + auto distances = raft::make_device_matrix(res, n_queries, k); + + cuvs::neighbors::cagra::search_params search_params; + search_params.algo = cuvs::neighbors::cagra::search_algo::MULTI_CTA; + search_params.itopk_size = 128; + search_params.max_queries = n_queries; + search_params.thread_block_size = 256; + + auto source = metadata_udf_source(); +/* run_filter builds a udf_filter that calls function_name from the shared source, searches with it, and validates the neighbors with the host predicate is_valid. */ + auto run_filter = + [&](char const* label, char const* function_name, float filtering_rate, auto is_valid) { + auto filter = cuvs::neighbors::filtering::udf_filter( + source, context_device.data_handle(), filtering_rate, function_name); + cuvs::neighbors::cagra::search(res, + search_params, + index, + raft::make_const_mdspan(queries.view()), + neighbors.view(), + distances.view(), + filter); + validate_and_print(label, res, neighbors.view(), is_valid); + }; +/* Rows have tenant i % 4 and each query names one tenant, so three of four rows are rejected (0.75). */ +
run_filter("tenant_filter", "tenant_filter", 0.75f, [&](int64_t query_id, uint32_t source_id) { + return row_tenant_ids[source_id] == query_tenant_ids[query_id]; + }); +/* The rejected fraction differs per query (minimum timestamps sit 3000, 2000, 1000 and 500 above the base, with 4096 rows); 0.50 is the single declared estimate. */ + run_filter( + "timestamp_filter", "timestamp_filter", 0.50f, [&](int64_t query_id, uint32_t source_id) { + return row_timestamps[source_id] >= query_min_timestamps[query_id]; + }); +/* Query q allows languages 2q and 2q plus 1 and permission bits 2q and 2q plus 8, so a row passes only when i % 16 is 2q or 2q plus 8: one row in eight, hence 0.875. */ + run_filter( + "language_acl_filter", + "language_acl_filter", + 0.875f, + [&](int64_t query_id, uint32_t source_id) { + const auto language_bit = uint64_t{1} << row_language_ids[source_id]; + const bool language_ok = (query_allowed_language_masks[query_id] & language_bit) != 0; + const bool acl_ok = (row_acl_masks[source_id] & query_permission_masks[query_id]) != 0; + return language_ok && acl_ok; + }); +/* Reached only if no validate_and_print call exited the process. */ + std::cout << "All CAGRA filter UDF examples produced valid filtered neighbors." << std::endl; + return 0; +} diff --git a/fern/pages/neighbors/cagra.md b/fern/pages/neighbors/cagra.md index 5f0496470a..c1eb80faea 100644 --- a/fern/pages/neighbors/cagra.md +++ b/fern/pages/neighbors/cagra.md @@ -915,6 +915,43 @@ if err != nil { +### C++ filter UDFs + +CAGRA C++ search can also use a JIT-LTO filter UDF when the predicate needs device metadata such as per-row tenants, ACLs, timestamps, or query-specific attributes. The UDF is a candidate-validity predicate: it decides whether a logical `source_id` is allowed for a logical `query_id`. It cannot control graph traversal, access PQ/VPQ or graph internals, or change distance computation.
+ +```cpp +struct tenant_filter_context { + const uint32_t* row_tenants; + const uint32_t* query_tenants; +}; + +std::string source = R"cpp( +struct tenant_filter_context { + const uint32_t* row_tenants; + const uint32_t* query_tenants; +}; + +__device__ bool cuvs_filter_udf(uint32_t query_id, + source_index_t source_id, + void* filter_data) +{ + auto* ctx = static_cast<const tenant_filter_context*>(filter_data); + return ctx->row_tenants[source_id] == ctx->query_tenants[query_id]; +} +)cpp"; + +tenant_filter_context host_ctx{row_tenants_device, query_tenants_device}; +tenant_filter_context* ctx_device = copy_to_device(host_ctx); + +auto filter = cuvs::neighbors::filtering::udf_filter(source, ctx_device, 0.75f); + +cagra::search(res, search_params, index, queries, neighbors, distances, filter); +``` + +The `filter_data` pointer is passed through unchanged to the device predicate. If the predicate dereferences it, the pointer and any nested pointers must refer to device-accessible memory and remain valid for the duration of the search. The `query_id` passed to the UDF is the global logical query id, including the batch offset when `max_queries` causes CAGRA to split a search into multiple batches. + +If `search_params::filtering_rate` is negative, CAGRA uses `udf_filter::filtering_rate`. If both are negative, CAGRA assumes `0.0` because UDF selectivity cannot be inferred from arbitrary CUDA source. Because CAGRA remains approximate, filtered results are not guaranteed to match exact brute-force filtered search, especially for highly selective predicates without an accurate `filtering_rate`.
+ ## Configuration parameters ### Build parameters diff --git a/fern/pages/working_with_ann_indexes.md b/fern/pages/working_with_ann_indexes.md index 12b3e00c63..3a299fdc10 100644 --- a/fern/pages/working_with_ann_indexes.md +++ b/fern/pages/working_with_ann_indexes.md @@ -19,3 +19,21 @@ For CAGRA bitset examples, see [Using Filters](/user-guide/api-guides/indexing-g A bitmap is based on the same principle as a bitset, but in two dimensions. This allows users to provide a different bitset for each query being searched. See the RAFT [bitmap API documentation](https://docs.rapids.ai/api/raft/stable/cpp_api/core_bitmap/) for more information. For Brute-force bitmap examples, see [Using Filters](/user-guide/api-guides/indexing-guide/brute-force#using-filters). + +### CAGRA filter UDF + +CAGRA also supports a low-level JIT-LTO filter UDF for C++ predicates that are more naturally expressed as CUDA device code. The UDF source defines a device function that returns `true` when a source vector is allowed and `false` when it should be rejected: + +```cpp +__device__ bool cuvs_filter_udf(uint32_t query_id, + source_index_t source_id, + void* filter_data); +``` + +`source_index_t` is currently `uint32_t` for CAGRA. `filter_data` is an opaque pointer passed through to the device predicate; if the UDF dereferences it, the pointer and any nested pointers must refer to device-accessible memory and remain valid for the duration of the search. The `query_id` passed to the UDF is the logical query id, including the batch offset when CAGRA splits a search into multiple query batches. + +When `cagra::search_params::filtering_rate` is negative, CAGRA uses `filtering::udf_filter::filtering_rate`. If both are negative, CAGRA assumes `0.0` because it cannot infer UDF selectivity from the source string. Providing a realistic filtering rate helps CAGRA size its internal search work for selective filters. + +Filter UDFs are candidate-validity predicates only. 
They receive logical query and source identifiers plus the caller-provided context pointer; they are not given access to CAGRA graph traversal state, IVF probing decisions, PQ/VPQ encoded data, or other internal index layouts. NVIDIA cuVS still owns traversal, distance computation, and result selection. + +Filtered CAGRA search remains approximate nearest-neighbor (ANN) search. The UDF prevents rejected candidates from appearing in returned results, but it does not guarantee exact brute-force filtered nearest-neighbor semantics. For CAGRA examples, see [Using Filters](/user-guide/api-guides/indexing-guide/cagra#using-filters).