-
Notifications
You must be signed in to change notification settings - Fork 194
Add JIT-LTO based filter UDF support for CAGRA #2132
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7c7070a
f6aeada
292f7e8
296d44c
67a03dd
1365f92
95aabc3
51716a6
3f021bb
5407d06
ee1475a
aa01651
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,256 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights | ||
| * reserved. SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
| #pragma once | ||
|
|
||
| #include "../../sample_filter.cuh" // public filter types | ||
| #include "../sample_filter_data.cuh" | ||
| #include "jit_lto_kernels/cagra_filter_payload.cuh" | ||
|
|
||
| #include <raft/core/error.hpp> | ||
|
|
||
| #include <cuda_runtime_api.h> | ||
|
|
||
| #include <cstddef> | ||
| #include <cstdint> | ||
| #include <cstring> | ||
| #include <memory> | ||
| #include <mutex> | ||
| #include <type_traits> | ||
| #include <unordered_map> | ||
| #include <vector> | ||
|
|
||
| namespace cuvs::neighbors::cagra::detail { | ||
|
|
||
| template <typename SourceIndexT> | ||
| using cagra_filter_data_storage = ::cuvs::neighbors::detail::bitset_filter_data_t<SourceIndexT>; | ||
|
|
||
| template <typename PayloadT> | ||
| std::uint64_t cagra_payload_hash(PayloadT const& payload) | ||
| { | ||
| static_assert(std::is_trivially_copyable_v<PayloadT>); | ||
| constexpr std::uint64_t kOffset = 1469598103934665603ull; | ||
| constexpr std::uint64_t kPrime = 1099511628211ull; | ||
| auto const* bytes = reinterpret_cast<unsigned char const*>(&payload); | ||
| std::uint64_t hash = kOffset; | ||
| for (std::size_t i = 0; i < sizeof(PayloadT); ++i) { | ||
| hash ^= bytes[i]; | ||
| hash *= kPrime; | ||
| } | ||
| return hash; | ||
| } | ||
|
|
||
| template <typename PayloadT> | ||
| struct cagra_device_payload_owner { | ||
| struct state { | ||
| PayloadT host_payload{}; | ||
| PayloadT* device_payload{nullptr}; | ||
| cudaStream_t stream{}; | ||
| cudaEvent_t ready_event{}; | ||
| int device{-1}; | ||
| std::mutex mutex; | ||
|
|
||
| explicit state(PayloadT payload) : host_payload(payload) {} | ||
|
|
||
| ~state() noexcept | ||
| { | ||
| if (device_payload != nullptr) { | ||
| RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(device_payload, stream)); | ||
| } | ||
| if (ready_event != nullptr) { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(ready_event)); } | ||
| } | ||
|
|
||
| PayloadT* dev_ptr(cudaStream_t cuda_stream) | ||
| { | ||
| std::lock_guard<std::mutex> lock(mutex); | ||
| if (device_payload == nullptr) { | ||
| RAFT_CUDA_TRY(cudaGetDevice(&device)); | ||
| RAFT_CUDA_TRY(cudaMallocAsync( | ||
| reinterpret_cast<void**>(&device_payload), sizeof(PayloadT), cuda_stream)); | ||
| RAFT_CUDA_TRY(cudaMemcpyAsync( | ||
| device_payload, &host_payload, sizeof(PayloadT), cudaMemcpyHostToDevice, cuda_stream)); | ||
| RAFT_CUDA_TRY(cudaEventCreateWithFlags(&ready_event, cudaEventDisableTiming)); | ||
| RAFT_CUDA_TRY(cudaEventRecord(ready_event, cuda_stream)); | ||
| stream = cuda_stream; | ||
| } else { | ||
| RAFT_CUDA_TRY(cudaStreamWaitEvent(cuda_stream, ready_event, 0)); | ||
| } | ||
| return device_payload; | ||
| } | ||
| }; | ||
|
|
||
| // PayloadT is copied to device by value. Pointer fields inside PayloadT are shallow-copied and | ||
| // must already point to device-addressable memory that remains valid while the cached payload is | ||
| // usable. | ||
| struct cache_key { | ||
| std::uint64_t payload_hash{}; | ||
| int device{}; | ||
|
|
||
| bool operator==(cache_key const& other) const | ||
| { | ||
| return payload_hash == other.payload_hash && device == other.device; | ||
| } | ||
| }; | ||
|
|
||
| struct cache_key_hash { | ||
| std::size_t operator()(cache_key const& key) const | ||
| { | ||
| auto seed = static_cast<std::size_t>(key.payload_hash); | ||
| seed ^= static_cast<std::size_t>(key.device) + 0x9e3779b9 + (seed << 6) + (seed >> 2); | ||
| return seed; | ||
| } | ||
| }; | ||
|
|
||
| cagra_device_payload_owner() = default; | ||
|
|
||
| void* dev_ptr(PayloadT payload, cudaStream_t stream) const | ||
| { | ||
| int device{}; | ||
| RAFT_CUDA_TRY(cudaGetDevice(&device)); | ||
|
|
||
| // Keep cached payload copies for process lifetime to avoid per-search allocation/copy churn. | ||
| // Cross-stream reuse is ordered by each state's ready_event before kernels consume the pointer. | ||
| const auto key = cache_key{cagra_payload_hash(payload), device}; | ||
| std::shared_ptr<state> selected_state; | ||
| { | ||
| std::lock_guard<std::mutex> lock(cache_mutex_); | ||
| auto& entries = cache_[key]; | ||
| for (auto const& cached : entries) { | ||
| if (std::memcmp(&cached->host_payload, &payload, sizeof(PayloadT)) == 0) { | ||
| selected_state = cached; | ||
| break; | ||
| } | ||
| } | ||
| if (selected_state == nullptr) { | ||
| selected_state = std::make_shared<state>(payload); | ||
| entries.push_back(selected_state); | ||
| } | ||
| } | ||
|
|
||
| return selected_state->dev_ptr(stream); | ||
| } | ||
|
|
||
| private: | ||
| mutable std::mutex cache_mutex_; | ||
| mutable std::unordered_map<cache_key, std::vector<std::shared_ptr<state>>, cache_key_hash> cache_; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. minor nit: |
||
| }; | ||
|
|
||
| template <typename T> | ||
| struct is_bitset_filter : std::false_type {}; | ||
|
|
||
| template <typename bitset_t, typename index_t> | ||
| struct is_bitset_filter<::cuvs::neighbors::filtering::bitset_filter<bitset_t, index_t>> | ||
| : std::true_type {}; | ||
|
|
||
| template <typename T> | ||
| struct is_udf_filter : std::false_type {}; | ||
|
|
||
| template <> | ||
| struct is_udf_filter<::cuvs::neighbors::filtering::udf_filter> : std::true_type {}; | ||
|
|
||
| template <typename SourceIndexT, typename FilterT> | ||
| cagra_filter_data_storage<SourceIndexT> make_cagra_filter_data_storage(const FilterT& filter) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok to address later: name this |
||
| { | ||
| const auto bitset_view = filter.view(); | ||
| return cagra_filter_data_storage<SourceIndexT>{ | ||
| const_cast<std::uint32_t*>(bitset_view.data()), | ||
| static_cast<SourceIndexT>(bitset_view.size()), | ||
| static_cast<SourceIndexT>(bitset_view.get_original_nbits())}; | ||
| } | ||
|
|
||
| template <typename PayloadT> | ||
| void* get_cagra_device_payload(PayloadT payload, cudaStream_t stream) | ||
| { | ||
| static cagra_device_payload_owner<PayloadT> owner; | ||
| return owner.dev_ptr(payload, stream); | ||
| } | ||
|
Comment on lines
+162
to
+167
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Call this in the factory? So the factory returns a perfectly valid device payload directly |
||
|
|
||
| template <typename SourceIndexT, typename FilterT> | ||
| void fill_cagra_sample_filter(cagra_sample_filter<SourceIndexT>& out, | ||
| const FilterT& filter, | ||
| cudaStream_t stream) | ||
| { | ||
| using DecayedFilter = std::decay_t<FilterT>; | ||
| if constexpr (is_bitset_filter<DecayedFilter>::value) { | ||
| out.filter_data = get_cagra_device_payload(make_cagra_filter_data_storage<SourceIndexT>(filter), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would not need to call here then, going by above comment |
||
| stream); | ||
| } else if constexpr (is_udf_filter<DecayedFilter>::value) { | ||
| out.filter_data = filter.filter_data; | ||
| } | ||
| } | ||
|
|
||
| template <typename SourceIndexT, typename FilterT> | ||
| std::uint64_t cagra_filter_payload_hash(const FilterT& filter) | ||
| { | ||
| using DecayedFilter = std::decay_t<FilterT>; | ||
| if constexpr (is_bitset_filter<DecayedFilter>::value) { | ||
| return cagra_payload_hash(make_cagra_filter_data_storage<SourceIndexT>(filter)); | ||
| } else if constexpr (requires { filter.filter; }) { | ||
| return cagra_filter_payload_hash<SourceIndexT>(filter.filter); | ||
| } else { | ||
| return 0; | ||
| } | ||
| } | ||
|
|
||
| template <typename FilterT> | ||
| void* cagra_filter_data_ptr(const FilterT& filter) | ||
| { | ||
| using DecayedFilter = std::decay_t<FilterT>; | ||
| if constexpr (is_udf_filter<DecayedFilter>::value) { | ||
| return filter.filter_data; | ||
| } else if constexpr (requires { filter.filter; }) { | ||
| return cagra_filter_data_ptr(filter.filter); | ||
| } else { | ||
| return nullptr; | ||
| } | ||
| } | ||
|
|
||
| template <typename SampleFilterT> | ||
| std::uint32_t cagra_filter_query_id_offset(const SampleFilterT& sample_filter) | ||
| { | ||
| if constexpr (requires { | ||
| sample_filter.filter; | ||
| sample_filter.offset; | ||
| }) { | ||
| return sample_filter.offset; | ||
| } else { | ||
| return 0; | ||
| } | ||
| } | ||
|
|
||
| /// Host: fill @ref cagra_sample_filter from a CAGRA filter object. | ||
| template <typename SourceIndexT, typename SampleFilterT> | ||
| cagra_sample_filter<SourceIndexT> extract_cagra_sample_filter(const SampleFilterT& sample_filter, | ||
| cudaStream_t stream) | ||
| { | ||
| cagra_sample_filter<SourceIndexT> out; | ||
| if constexpr (requires { | ||
| sample_filter.filter; | ||
| sample_filter.offset; | ||
| }) { | ||
| out.query_id_offset = sample_filter.offset; | ||
| fill_cagra_sample_filter(out, sample_filter.filter, stream); | ||
| } else { | ||
| fill_cagra_sample_filter(out, sample_filter, stream); | ||
| } | ||
| return out; | ||
| } | ||
|
|
||
| /// Host: find UDF compile/link metadata only. Query offsets stay in the runtime payload produced | ||
| /// by @ref extract_cagra_sample_filter and are applied before calling the linked sample_filter. | ||
| template <typename SampleFilterT> | ||
| const ::cuvs::neighbors::filtering::udf_filter* get_cagra_udf_filter( | ||
| const SampleFilterT& sample_filter) | ||
| { | ||
| using DecayedFilter = std::decay_t<SampleFilterT>; | ||
| if constexpr (is_udf_filter<DecayedFilter>::value) { | ||
| return &sample_filter; | ||
| } else if constexpr (requires { sample_filter.filter; }) { | ||
| return get_cagra_udf_filter(sample_filter.filter); | ||
| } else { | ||
| return nullptr; | ||
| } | ||
| } | ||
|
|
||
| } // namespace cuvs::neighbors::cagra::detail | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 2714
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 7203
🏁 Script executed:
Repository: rapidsai/cuvs
Length of output: 1948
HIGH: Partially-initialized cached device payload on CUDA failure (missing rollback)
cagra_device_payload_owner::dev_ptr()inserts a newstateintocache_[key]beforestate::dev_ptr()runs. If anystate::dev_ptr()CUDA call fails aftercudaMallocAsyncsucceeds (e.g.,cudaMemcpyAsync,cudaEventCreateWithFlags,cudaEventRecord), the cachedstatecan persist withdevice_payload != nullptrwhileready_eventremains default/invalid. Subsequent calls take theelsepath and callcudaStreamWaitEvent(cuda_stream, ready_event, 0), and the leaked allocation can persist for process lifetime since the cached entry is never removed.Suggested fix
PayloadT* dev_ptr(cudaStream_t cuda_stream) { std::lock_guard<std::mutex> lock(mutex); if (device_payload == nullptr) { RAFT_CUDA_TRY(cudaGetDevice(&device)); - RAFT_CUDA_TRY(cudaMallocAsync( - reinterpret_cast<void**>(&device_payload), sizeof(PayloadT), cuda_stream)); - RAFT_CUDA_TRY(cudaMemcpyAsync( - device_payload, &host_payload, sizeof(PayloadT), cudaMemcpyHostToDevice, cuda_stream)); - RAFT_CUDA_TRY(cudaEventCreateWithFlags(&ready_event, cudaEventDisableTiming)); - RAFT_CUDA_TRY(cudaEventRecord(ready_event, cuda_stream)); - stream = cuda_stream; + PayloadT* new_device_payload{nullptr}; + cudaEvent_t new_ready_event{}; + try { + RAFT_CUDA_TRY(cudaMallocAsync( + reinterpret_cast<void**>(&new_device_payload), sizeof(PayloadT), cuda_stream)); + RAFT_CUDA_TRY(cudaMemcpyAsync(new_device_payload, + &host_payload, + sizeof(PayloadT), + cudaMemcpyHostToDevice, + cuda_stream)); + RAFT_CUDA_TRY(cudaEventCreateWithFlags(&new_ready_event, cudaEventDisableTiming)); + RAFT_CUDA_TRY(cudaEventRecord(new_ready_event, cuda_stream)); + device_payload = new_device_payload; + ready_event = new_ready_event; + stream = cuda_stream; + } catch (...) { + if (new_ready_event != nullptr) { + RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(new_ready_event)); + } + if (new_device_payload != nullptr) { + RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(new_device_payload, cuda_stream)); + } + throw; + } } else { RAFT_CUDA_TRY(cudaStreamWaitEvent(cuda_stream, ready_event, 0)); } return device_payload; }🤖 Prompt for AI Agents