Add JIT-LTO based filter UDF support for CAGRA#2132
Conversation
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
| /** Optional stable cache key for equivalent generated source. */ | ||
| std::string cache_key; | ||
| /** Device function name to call from the generated CAGRA sample filter. */ | ||
| std::string function_name = "cuvs_filter_udf"; |
There was a problem hiding this comment.
Why expose these 2? We can just hide it within implementation details. As you noted previously in IVF Flat, the cache key can just be the string itself
There was a problem hiding this comment.
Agree on cache_key; it is an implementation detail and can be derived internally from the source/generated code.
For function_name, the intent was to support one source string with several predicates over the same metadata context, e.g. tenant_filter, timestamp_filter, and acl_filter, and let the user select which one to link without regenerating wrapper source each time. What do you think?
There was a problem hiding this comment.
Oh that is pretty smart, so you just pay compilation costs once.
| /// Host/device payload for linked CAGRA sample filters plus query offset for wrapped filters. | ||
| template <typename SourceIndexT> | ||
| struct cagra_sample_filter { | ||
| cagra_bitset<SourceIndexT> bitset{}; |
There was a problem hiding this comment.
Can we hide this behind the opaque filter_data? bitset is an implementation detail of the bitset_filter
There was a problem hiding this comment.
refactored this into a generic cagra_filter_payload.cuh and renamed the embedded storage to filter_data_storage in 296d44c
| out.bitset.bitset_ptr = const_cast<std::uint32_t*>(bitset_view.data()); | ||
| out.bitset.bitset_len = static_cast<SourceIndexT>(bitset_view.size()); | ||
| out.bitset.original_nbits = static_cast<SourceIndexT>(bitset_view.get_original_nbits()); |
There was a problem hiding this comment.
So going by my above comment, you would assign out.filter_data = cagra_bitset<SourceIndexT>{bitset_ptr, bitset_len, original_nbits} and then reinterpret this in the fragment
| template <typename SourceIndexT> | ||
| __device__ __forceinline__ void* get_cagra_sample_filter_data( | ||
| cagra_sample_filter<SourceIndexT>& payload) | ||
| { | ||
| if (payload.filter_kind == cagra_filter_kind::udf) { return payload.filter_data; } | ||
| if (payload.filter_kind == cagra_filter_kind::bitset) { return &payload.bitset; } | ||
| return nullptr; | ||
| } |
There was a problem hiding this comment.
If you make bitset opaque then you don't need this function
|
Worried about impact? Review this PR in Change Stack to explore blast radius before you approve or request changes. Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughSummary by CodeRabbit
WalkthroughAdds UDF-based device filters to CAGRA: new ChangesCAGRA UDF Filter Implementation
Tests and Documentation
🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (4)
cpp/include/cuvs/neighbors/common.hpp (1)
634-658: 💤 Low valueConsider validating required UDF fields in the constructor.
The
udf_filterconstructor acceptssourceandfunction_namebut does not validate that they are non-empty. While validation occurs later during JIT compilation (inmake_cagra_sample_filter_udf_fragment), failing early in the constructor would provide clearer error messages and better user experience.💡 Suggested defensive check
explicit udf_filter(std::string source, void* filter_data = nullptr, float filtering_rate = -1.0f, std::string function_name = "cuvs_filter_udf") : source(std::move(source)), filter_data(filter_data), filtering_rate(filtering_rate), function_name(std::move(function_name)) { + // Optional: validate required fields early + // RAFT_EXPECTS(!this->source.empty(), "UDF source must not be empty"); + // RAFT_EXPECTS(!this->function_name.empty(), "UDF function name must not be empty"); }🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/include/cuvs/neighbors/common.hpp` around lines 634 - 658, The udf_filter constructor should validate that required strings are non-empty; update the explicit udf_filter(std::string source, void* filter_data, float filtering_rate, std::string function_name) to check source and function_name and throw a clear exception (e.g., std::invalid_argument) or call an assert if either is empty, so invalid UDFs fail fast (this is the same validation later done in make_cagra_sample_filter_udf_fragment but should be performed at construction time in udf_filter).cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh (1)
34-42: 💤 Low valueConsider namespace isolation for generated type aliases.
Lines 34-40 inject type aliases (
int8_t,uint32_t, etc.) into the global namespace of the generated CUDA source. While this is generally safe since NVRTC compiles each fragment in isolation, explicitly scoping these types or documenting the naming convention would prevent potential confusion if users' UDF code expects these types from<cstdint>.The current approach works because:
- Generated source is compiled separately via NVRTC
- User UDF code will see these definitions
However, for robustness, consider either:
- Adding a comment explaining why global aliases are safe here
- Or including actual headers (
#include <cstdint>) in the generated source🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh` around lines 34 - 42, The generated CUDA fragment in sample_filter_udf.cuh currently injects fundamental type aliases via the oss R"(...)" block (e.g., int8_t, uint32_t, source_index_t) into the global namespace; modify the generator to either (A) wrap these aliases in a dedicated namespace (e.g., namespace jit_types { using int8_t = signed char; ... using source_index_t = <source_index_type>; }) and update any emitted code to reference jit_types::source_index_t, or (B) include the proper header by emitting `#include` <cstdint> at the top of the generated source and only emit the project-specific alias source_index_t (still qualified or namespaced). Update uses of source_index_t in the generated UDF to reference the chosen namespace if you choose (A); keep changes localized to the oss string assembly in sample_filter_udf.cuh.cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh (1)
54-59: ⚡ Quick winClarify const_cast usage for CAGRA bitset payload
bitset_filter_data_t::bitset_ptris defined asstd::uint32_t*(non-const), so theconst_cast<std::uint32_t*>(bitset_view.data())incagra_filter_payload.cuhis required to assignfilter.view().data()into that payload type; the device code then treats the bitset as read-only (builds aconstraft::core::bitset_viewand only callstest).Optional: change
bitset_filter_data_t::bitset_ptr(and corresponding kernel parameter types) toconst std::uint32_t*to remove the cast, or document explicitly that the bitset is immutable during search.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh` around lines 54 - 59, The code uses const_cast to assign filter.view().data() into cagra_filter_data_storage because bitset_filter_data_t::bitset_ptr is a non-const std::uint32_t*; replace this unsafe cast by making the payload and kernel types explicitly const: change bitset_filter_data_t::bitset_ptr (and any kernel parameter types and cagra_filter_data_storage template instantiations) to const std::uint32_t* so the bitset is treated as immutable, and update any affected functions (e.g., places constructing cagra_filter_data_storage in cagra_filter_payload.cuh and kernel signatures that consume it) to accept the const pointer, or alternatively add a clear comment next to bitset_filter_data_t::bitset_ptr and the const raft::core::bitset_view usage documenting that the bitset is immutable during search if you prefer to keep the current type.cpp/tests/neighbors/ann_cagra/test_filter_udf.cu (1)
106-170: ⚡ Quick winConsider adding edge case tests for robustness.
The test fixture provides good coverage of the UDF filtering mechanism, but per coding guidelines, numerical correctness tests should validate edge cases. Consider adding tests for:
- Empty dataset or query set (n_rows=0 or n_queries=0)
- Single-element dataset (n_rows=1)
- Requesting k neighbors when fewer than k rows pass the filter
These edge cases would help ensure the UDF filtering path handles boundary conditions consistently with other CAGRA filter types.
As per coding guidelines: Numerical correctness tests must validate edge cases: empty inputs, single elements, zero-norm vectors, identical points, and extreme values.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tests/neighbors/ann_cagra/test_filter_udf.cu` around lines 106 - 170, Add edge-case unit tests using the existing CagraUdfFilterTest fixture: create new TEST_P cases that override n_rows and n_queries (set to 0 and 1 as needed) and invoke the search(...) helper with UDF filters to validate behavior for empty dataset, empty queries, single-row dataset, and requesting k larger than available passing rows; ensure you assert on cagra_search_result.neighbors and .distances for expected shapes/values and that no crashes or undefined-copy occur (use the existing search function, the class CagraUdfFilterTest, and the cagra_search_result return value to drive the checks).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@cpp/src/neighbors/cagra.cuh`:
- Around line 392-399: The code currently uses std::min/std::max to clamp
sample_filter.filtering_rate into params_copy.filtering_rate but does not guard
against NaN/infinite values, so a NaN in sample_filter.filtering_rate will
propagate; update the block (around params.filtering_rate,
params_copy.filtering_rate, and sample_filter.filtering_rate) to first validate
finiteness (e.g., std::isfinite or !std::isnan) and if the value is not finite
either reject it or set params_copy.filtering_rate to a safe default (0.0f)
before applying std::min/std::max, and ensure any rejection path emits an
appropriate error/return per existing error handling policy.
In `@cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh`:
- Around line 21-27: Update the public Doxygen for the udf_filter API to note
that CAGRA-generated JIT fragments use source_index_t == uint32_t (i.e.,
SourceIndexT is currently restricted to uint32_t), so UDF filters must expect a
uint32_t source index; reference the internal symbol
cagra_udf_source_index_type_name and the typedef/source_index_t in the comment
so users know this constraint up front.
---
Nitpick comments:
In `@cpp/include/cuvs/neighbors/common.hpp`:
- Around line 634-658: The udf_filter constructor should validate that required
strings are non-empty; update the explicit udf_filter(std::string source, void*
filter_data, float filtering_rate, std::string function_name) to check source
and function_name and throw a clear exception (e.g., std::invalid_argument) or
call an assert if either is empty, so invalid UDFs fail fast (this is the same
validation later done in make_cagra_sample_filter_udf_fragment but should be
performed at construction time in udf_filter).
In `@cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh`:
- Around line 54-59: The code uses const_cast to assign filter.view().data()
into cagra_filter_data_storage because bitset_filter_data_t::bitset_ptr is a
non-const std::uint32_t*; replace this unsafe cast by making the payload and
kernel types explicitly const: change bitset_filter_data_t::bitset_ptr (and any
kernel parameter types and cagra_filter_data_storage template instantiations) to
const std::uint32_t* so the bitset is treated as immutable, and update any
affected functions (e.g., places constructing cagra_filter_data_storage in
cagra_filter_payload.cuh and kernel signatures that consume it) to accept the
const pointer, or alternatively add a clear comment next to
bitset_filter_data_t::bitset_ptr and the const raft::core::bitset_view usage
documenting that the bitset is immutable during search if you prefer to keep the
current type.
In `@cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh`:
- Around line 34-42: The generated CUDA fragment in sample_filter_udf.cuh
currently injects fundamental type aliases via the oss R"(...)" block (e.g.,
int8_t, uint32_t, source_index_t) into the global namespace; modify the
generator to either (A) wrap these aliases in a dedicated namespace (e.g.,
namespace jit_types { using int8_t = signed char; ... using source_index_t =
<source_index_type>; }) and update any emitted code to reference
jit_types::source_index_t, or (B) include the proper header by emitting `#include`
<cstdint> at the top of the generated source and only emit the project-specific
alias source_index_t (still qualified or namespaced). Update uses of
source_index_t in the generated UDF to reference the chosen namespace if you
choose (A); keep changes localized to the oss string assembly in
sample_filter_udf.cuh.
In `@cpp/tests/neighbors/ann_cagra/test_filter_udf.cu`:
- Around line 106-170: Add edge-case unit tests using the existing
CagraUdfFilterTest fixture: create new TEST_P cases that override n_rows and
n_queries (set to 0 and 1 as needed) and invoke the search(...) helper with UDF
filters to validate behavior for empty dataset, empty queries, single-row
dataset, and requesting k larger than available passing rows; ensure you assert
on cagra_search_result.neighbors and .distances for expected shapes/values and
that no crashes or undefined-copy occur (use the existing search function, the
class CagraUdfFilterTest, and the cagra_search_result return value to drive the
checks).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 4d44a786-cdda-4dd1-9a2c-9ef7fcde1b54
📒 Files selected for processing (31)
cpp/include/cuvs/detail/jit_lto/common_fragments.hppcpp/include/cuvs/neighbors/cagra.hppcpp/include/cuvs/neighbors/common.hppcpp/src/neighbors/cagra.cuhcpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.incpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_bitset.cuhcpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuhcpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hppcpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hppcpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.incpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hppcpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuhcpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuhcpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.incpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuhcpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuhcpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.incpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_kernel.cu.incpp/src/neighbors/detail/cagra/search_multi_cta_inst.cu.incpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuhcpp/src/neighbors/detail/cagra/search_multi_kernel.cuhcpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuhcpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.incpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuhcpp/src/neighbors/detail/cagra/shared_launcher_jit.hppcpp/tests/CMakeLists.txtcpp/tests/neighbors/ann_cagra/test_filter_udf.cuexamples/cpp/CMakeLists.txtexamples/cpp/src/cagra_filter_udf_example.cufern/pages/neighbors/cagra.mdfern/pages/working_with_ann_indexes.md
💤 Files with no reviewable changes (1)
- cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_bitset.cuh
| if (params.filtering_rate < 0.0) { | ||
| const float min_filtering_rate = 0.0f; | ||
| const float max_filtering_rate = 0.999f; | ||
| params_copy.filtering_rate = sample_filter.filtering_rate < 0.0f | ||
| ? 0.0f | ||
| : std::min(std::max(sample_filter.filtering_rate, | ||
| min_filtering_rate), | ||
| max_filtering_rate); |
There was a problem hiding this comment.
Reject NaN filtering_rate before clamping.
std::min/std::max do not sanitize NaN, so udf_filter.filtering_rate = NaN will propagate a NaN into params_copy.filtering_rate and then into downstream planner/kernel selection. Please validate finiteness here and either reject it or fall back to 0.0f.
Suggested fix
+#include <cmath>
+
if (params.filtering_rate < 0.0) {
const float min_filtering_rate = 0.0f;
const float max_filtering_rate = 0.999f;
+ RAFT_EXPECTS(std::isfinite(sample_filter.filtering_rate) ||
+ sample_filter.filtering_rate < 0.0f,
+ "UDF filtering_rate must be finite or negative");
params_copy.filtering_rate = sample_filter.filtering_rate < 0.0f
? 0.0f
- : std::min(std::max(sample_filter.filtering_rate,
- min_filtering_rate),
- max_filtering_rate);
+ : std::clamp(sample_filter.filtering_rate,
+ min_filtering_rate,
+ max_filtering_rate);
}As per coding guidelines, "Input validation must check for negative or invalid dimensions, null pointers, and invalid parameter combinations before GPU operations."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@cpp/src/neighbors/cagra.cuh` around lines 392 - 399, The code currently uses
std::min/std::max to clamp sample_filter.filtering_rate into
params_copy.filtering_rate but does not guard against NaN/infinite values, so a
NaN in sample_filter.filtering_rate will propagate; update the block (around
params.filtering_rate, params_copy.filtering_rate, and
sample_filter.filtering_rate) to first validate finiteness (e.g., std::isfinite
or !std::isnan) and if the value is not finite either reject it or set
params_copy.filtering_rate to a safe default (0.0f) before applying
std::min/std::max, and ensure any rejection path emits an appropriate
error/return per existing error handling policy.
| template <typename SourceIndexT> | ||
| inline constexpr const char* cagra_udf_source_index_type_name() | ||
| { | ||
| static_assert(std::is_same_v<SourceIndexT, std::uint32_t>, | ||
| "CAGRA filter UDFs currently support SourceIndexT = uint32_t only"); | ||
| return "uint32_t"; | ||
| } |
There was a problem hiding this comment.
🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win
Document SourceIndexT=uint32_t restriction in public API.
The static_assert restricts UDF filters to SourceIndexT = uint32_t, but this limitation is not mentioned in the public udf_filter documentation. Users won't discover this constraint until JIT compilation fails.
Consider adding a note in the udf_filter Doxygen comment (cpp/include/cuvs/neighbors/common.hpp lines 620-633) stating that CAGRA currently provides source_index_t as uint32_t in the generated JIT fragment, matching the documentation on line 631-632.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh` around
lines 21 - 27, Update the public Doxygen for the udf_filter API to note that
CAGRA-generated JIT fragments use source_index_t == uint32_t (i.e., SourceIndexT
is currently restricted to uint32_t), so UDF filters must expect a uint32_t
source index; reference the internal symbol cagra_udf_source_index_type_name and
the typedef/source_index_t in the comment so users know this constraint up
front.
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh`:
- Around line 59-115: The current cagra_device_payload_owner uses a single
shared state (struct state) with one device_payload and stream per hashed
payload (keyed by cagra_payload_hash), which is not stream-safe and leaks device
allocations in the static cache; change the design so the device copy is managed
per stream (or synchronized across streams) and make the cache evictable.
Concretely: replace the single device_payload/stream in state with a per-stream
mapping (e.g., map<cudaStream_t, device_entry> or use raft::stream_syncer to
serialize cross-stream access) so dev_ptr(cudaStream_t) either returns a device
copy created on that stream or waits on the creator’s event before reuse; also
ensure the static cache supports eviction (e.g., weak_ptrs or explicit erase
when state has no live owners) to avoid persistent allocations. Update
cagra_device_payload_owner::cagra_device_payload_owner, state::dev_ptr, and the
static cache usage to implement per-stream entries and proper stream
synchronization (or stream_syncer) and add eviction logic.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 1b9c2fd6-80f8-4ab2-beda-c8ec4c98709d
📒 Files selected for processing (4)
cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuhcpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuhcpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuhcpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh
🚧 Files skipped from review as they are similar to previous changes (2)
- cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh
- cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh
| #include "../../../sample_filter.cuh" // public filter types | ||
| #include "../../sample_filter_data.cuh" | ||
|
|
||
| #if !defined(__CUDACC_RTC__) |
There was a problem hiding this comment.
Are you including this header in runtime compilation? How are you doing so?
There was a problem hiding this comment.
Yes, indirectly. The generated JIT kernel sources include kernel_def.hpp / the JIT kernel headers, and those need the cagra_sample_filter parameter type in the kernel signatures. That’s why the small jit_lto_kernels/cagra_filter_payload.cuh header remains in the runtime-compiled/JIT side.
I moved the host-only ownership/cache/extraction logic out of this header into cagra_filter_payload.hpp, so the runtime-compiled side only sees the minimal kernel ABI struct.
| } | ||
|
|
||
| private: | ||
| mutable std::shared_ptr<state> state_; |
There was a problem hiding this comment.
Just wanted to point out that by using a shared_ptr that is initialized on construction, it becomes a requirement that state be initialized before threads branch off. Would that be the case if the user were handling threading and calling search? If not, would it make sense to make this static instead?
| template <typename SourceIndexT> | ||
| struct cagra_sample_filter_payload { | ||
| cagra_sample_filter<SourceIndexT> payload{}; | ||
| cagra_device_payload_owner<cagra_filter_data_storage<SourceIndexT>> storage_owner{}; |
There was a problem hiding this comment.
cagra_filter_data_storage is still of type bitset. So we will have to add an std::variant here to support pre-compiled filter types? In which case, why even have the owner, we could just go back to storing it by value for pre-compiled types.
| /// Host: find UDF compile/link metadata only. Query offsets stay in the runtime payload produced | ||
| /// by @ref extract_cagra_sample_filter and are applied before calling the linked sample_filter. | ||
| template <typename SampleFilterT> | ||
| const ::cuvs::neighbors::filtering::udf_filter* get_cagra_udf_filter( |
There was a problem hiding this comment.
This function is a bit ugly. Is it possible to enforce that any and all CAGRA filters are wrapped by the outer SampleFilter type?
| { | ||
| const auto bf = extract_cagra_sample_filter<SourceIndexT>(sample_filter); | ||
| const uint32_t query_id_offset = bf.query_id_offset; | ||
| const auto filter_payload_owner = extract_cagra_sample_filter<SourceIndexT>(sample_filter); |
There was a problem hiding this comment.
Okay yep, so this is what I was talking about. filter_payload_owner is allocated on the stack here, which means it will be de-allocated at the end of the function call. For externally managed threading, this means that the device filter payload will be allocated and deallocated on every search call, instead of it being a no-op after the first one.
| cagra_sample_filter_payload<SourceIndexT> extract_cagra_sample_filter( | ||
| const SampleFilterT& sample_filter) | ||
| { | ||
| cagra_sample_filter_payload<SourceIndexT> out; |
There was a problem hiding this comment.
I think this should be declared static before returning
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@cpp/src/neighbors/detail/cagra/cagra_filter_payload.hpp`:
- Around line 67-75: cagra_device_payload_owner::dev_ptr() can leave cache_[key]
with a partially-initialized state (device_payload set but ready_event invalid)
if any CUDA call after cudaMallocAsync fails; update state::dev_ptr() to perform
rollback on any CUDA failure: if cudaMallocAsync succeeded then on subsequent
error call cudaFreeAsync (or cudaFree if stream not available) to free
device_payload, destroy any created ready_event via cudaEventDestroy, reset
state members (device_payload = nullptr, ready_event = default/0, stream =
nullptr) and ensure the cache entry is removed or marked invalid (erase
cache_[key] or set a flag) so future calls do not use the leaked pointer and
invalid event; wrap CUDA calls with a try/catch or error-check path to execute
this cleanup whenever a RAFT_CUDA_TRY-like check reports failure.
In `@cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh`:
- Around line 623-624: The code binds filter_payload via
extract_cagra_sample_filter<SourceIndexT> to the caller stream but then lets the
persistent kernel consume it on the runner’s private stream, creating a race;
before assigning this->filter_payload (and before the persistent kernel launch
that reads query_id_offset), synchronize the request stream into the runner
stream using raft::stream_syncer (or an event handoff) so the device buffer
updates are visible and ordered; update the path around
extract_cagra_sample_filter<SourceIndexT>, filter_payload and the persistent
kernel launch to perform the stream handoff (use raft::stream_syncer to wait on
the request stream or record/wait an event) so the persistent kernel never reads
stale/torn UDF filter data.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 412ea828-d7c9-4e04-a22c-c9ef9d63f0bd
📒 Files selected for processing (7)
cpp/src/neighbors/detail/cagra/cagra_filter_payload.hppcpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuhcpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuhcpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuhcpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuhcpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuhcpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp
💤 Files with no reviewable changes (1)
- cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh
🚧 Files skipped from review as they are similar to previous changes (4)
- cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh
- cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp
- cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh
- cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh
| if (device_payload == nullptr) { | ||
| RAFT_CUDA_TRY(cudaGetDevice(&device)); | ||
| RAFT_CUDA_TRY(cudaMallocAsync( | ||
| reinterpret_cast<void**>(&device_payload), sizeof(PayloadT), cuda_stream)); | ||
| RAFT_CUDA_TRY(cudaMemcpyAsync( | ||
| device_payload, &host_payload, sizeof(PayloadT), cudaMemcpyHostToDevice, cuda_stream)); | ||
| RAFT_CUDA_TRY(cudaEventCreateWithFlags(&ready_event, cudaEventDisableTiming)); | ||
| RAFT_CUDA_TRY(cudaEventRecord(ready_event, cuda_stream)); | ||
| stream = cuda_stream; |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
file=$(fd -p 'cagra_filter_payload.hpp' | head -n1)
echo '--- state::dev_ptr initialization block ---'
sed -n '64,80p' "$file"
echo
echo '--- cache insertion before initialization ---'
sed -n '116,129p' "$file"
echo
echo '--- fallible CUDA setup calls and nearby cleanup/catch ---'
rg -n 'cudaMallocAsync|cudaMemcpyAsync|cudaEventCreateWithFlags|cudaEventRecord|cudaFreeAsync|catch|device_payload|ready_event' "$file"Repository: rapidsai/cuvs
Length of output: 2714
🏁 Script executed:
#!/bin/bash
set -euo pipefail
file=$(fd -p 'cagra_filter_payload.hpp' | head -n1)
echo "=== File: $file ==="
echo
echo "---- state struct definition (around lines 40-90) ----"
sed -n '35,95p' "$file"
echo
echo "---- cache insertion (around lines 100-150) ----"
sed -n '95,150p' "$file"
echo
echo "---- get_cagra_device_payload (around lines 150-220) ----"
sed -n '150,230p' "$file"
echo
echo "---- RAFT_CUDA_TRY macro behavior (best-effort grep) ----"
rg -n "define RAFT_CUDA_TRY" -S . | head -n 20 || trueRepository: rapidsai/cuvs
Length of output: 7203
🏁 Script executed:
#!/bin/bash
set -euo pipefail
rg -n "define RAFT_CUDA_TRY\\(" -S . --hidden --no-ignore-vcsRepository: rapidsai/cuvs
Length of output: 1948
HIGH: Partially-initialized cached device payload on CUDA failure (missing rollback)
cagra_device_payload_owner::dev_ptr() inserts a new state into cache_[key] before state::dev_ptr() runs. If any state::dev_ptr() CUDA call fails after cudaMallocAsync succeeds (e.g., cudaMemcpyAsync, cudaEventCreateWithFlags, cudaEventRecord), the cached state can persist with device_payload != nullptr while ready_event remains default/invalid. Subsequent calls take the else path and call cudaStreamWaitEvent(cuda_stream, ready_event, 0), and the leaked allocation can persist for process lifetime since the cached entry is never removed.
Suggested fix
PayloadT* dev_ptr(cudaStream_t cuda_stream)
{
std::lock_guard<std::mutex> lock(mutex);
if (device_payload == nullptr) {
RAFT_CUDA_TRY(cudaGetDevice(&device));
- RAFT_CUDA_TRY(cudaMallocAsync(
- reinterpret_cast<void**>(&device_payload), sizeof(PayloadT), cuda_stream));
- RAFT_CUDA_TRY(cudaMemcpyAsync(
- device_payload, &host_payload, sizeof(PayloadT), cudaMemcpyHostToDevice, cuda_stream));
- RAFT_CUDA_TRY(cudaEventCreateWithFlags(&ready_event, cudaEventDisableTiming));
- RAFT_CUDA_TRY(cudaEventRecord(ready_event, cuda_stream));
- stream = cuda_stream;
+ PayloadT* new_device_payload{nullptr};
+ cudaEvent_t new_ready_event{};
+ try {
+ RAFT_CUDA_TRY(cudaMallocAsync(
+ reinterpret_cast<void**>(&new_device_payload), sizeof(PayloadT), cuda_stream));
+ RAFT_CUDA_TRY(cudaMemcpyAsync(new_device_payload,
+ &host_payload,
+ sizeof(PayloadT),
+ cudaMemcpyHostToDevice,
+ cuda_stream));
+ RAFT_CUDA_TRY(cudaEventCreateWithFlags(&new_ready_event, cudaEventDisableTiming));
+ RAFT_CUDA_TRY(cudaEventRecord(new_ready_event, cuda_stream));
+ device_payload = new_device_payload;
+ ready_event = new_ready_event;
+ stream = cuda_stream;
+ } catch (...) {
+ if (new_ready_event != nullptr) {
+ RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(new_ready_event));
+ }
+ if (new_device_payload != nullptr) {
+ RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(new_device_payload, cuda_stream));
+ }
+ throw;
+ }
} else {
RAFT_CUDA_TRY(cudaStreamWaitEvent(cuda_stream, ready_event, 0));
}
return device_payload;
}🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@cpp/src/neighbors/detail/cagra/cagra_filter_payload.hpp` around lines 67 -
75, cagra_device_payload_owner::dev_ptr() can leave cache_[key] with a
partially-initialized state (device_payload set but ready_event invalid) if any
CUDA call after cudaMallocAsync fails; update state::dev_ptr() to perform
rollback on any CUDA failure: if cudaMallocAsync succeeded then on subsequent
error call cudaFreeAsync (or cudaFree if stream not available) to free
device_payload, destroy any created ready_event via cudaEventDestroy, reset
state members (device_payload = nullptr, ready_event = default/0, stream =
nullptr) and ensure the cache entry is removed or marked invalid (erase
cache_[key] or set a flag) so future calls do not use the leaked pointer and
invalid event; wrap CUDA calls with a try/catch or error-check path to execute
this cleanup whenever a RAFT_CUDA_TRY-like check reports failure.
| this->filter_payload = extract_cagra_sample_filter<SourceIndexT>(sample_filter, stream); | ||
| const uint32_t query_id_offset = filter_payload.query_id_offset; |
There was a problem hiding this comment.
HIGH: Missing stream handoff for UDF filter context on the persistent path.
filter_payload is bound to the runner’s private CUDA stream here, but the PR explicitly allows filter_data to carry per-query runtime metadata. On cache hits, the persistent kernel can consume that device buffer without any ordering against the caller’s stream, so async updates to the UDF context can race and produce stale or torn predicate decisions. Have you considered syncing the request stream into the runner stream before the persistent kernel observes filter_payload (for example with raft::stream_syncer or an event handoff)? As per coding guidelines, use raft::stream_syncer for proper stream ordering in multi-threaded contexts.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh`
around lines 623 - 624, The code binds filter_payload via
extract_cagra_sample_filter<SourceIndexT> to the caller stream but then lets the
persistent kernel consume it on the runner’s private stream, creating a race;
before assigning this->filter_payload (and before the persistent kernel launch
that reads query_id_offset), synchronize the request stream into the runner
stream using raft::stream_syncer (or an event handoff) so the device buffer
updates are visible and ordered; update the path around
extract_cagra_sample_filter<SourceIndexT>, filter_payload and the persistent
kernel launch to perform the stream handoff (use raft::stream_syncer to wait on
the request stream or record/wait an event) so the persistent kernel never reads
stale/torn UDF filter data.
divyegala
left a comment
There was a problem hiding this comment.
Minor comments, fantastic PR!
|
|
||
| private: | ||
| mutable std::mutex cache_mutex_; | ||
| mutable std::unordered_map<cache_key, std::vector<std::shared_ptr<state>>, cache_key_hash> cache_; |
There was a problem hiding this comment.
minor nit: state does not need to be a shared_ptr anymore as the owner itself is static
| struct is_udf_filter<::cuvs::neighbors::filtering::udf_filter> : std::true_type {}; | ||
|
|
||
| template <typename SourceIndexT, typename FilterT> | ||
| cagra_filter_data_storage<SourceIndexT> make_cagra_filter_data_storage(const FilterT& filter) |
There was a problem hiding this comment.
Ok to address later: name this make_cagra_bitset_filter_storage so we can add factories for other pre-compiled types and don't need the specific alias anymore
| template <typename PayloadT> | ||
| void* get_cagra_device_payload(PayloadT payload, cudaStream_t stream) | ||
| { | ||
| static cagra_device_payload_owner<PayloadT> owner; | ||
| return owner.dev_ptr(payload, stream); | ||
| } |
There was a problem hiding this comment.
Call this in the factory? So the factory returns a perfectly valid device payload directly
| { | ||
| using DecayedFilter = std::decay_t<FilterT>; | ||
| if constexpr (is_bitset_filter<DecayedFilter>::value) { | ||
| out.filter_data = get_cagra_device_payload(make_cagra_filter_data_storage<SourceIndexT>(filter), |
There was a problem hiding this comment.
Would not need to call here then, going by above comment
This PR adds a first version of low-level JIT-LTO filter UDF support for CAGRA search in cuVS C++.
Users can provide CUDA device source for a predicate with this ABI:
The predicate returns
trueto allow a candidate andfalseto reject it.filter_datais not result data; it is an optional opaque device-accessible context pointer that lets the predicate read user metadata such as tenant ids, timestamps, language masks, ACL bitmaps, or query-specific thresholds. Results are still written to the normal CAGRAneighborsanddistancesoutputs.This gives CAGRA a path to support runtime metadata predicates without ahead-of-time template explosion, while keeping existing
noneandbitsetfilters on their static JIT-LTO fragment paths.What Changed
cuvs::neighbors::filtering::udf_filter.FilterType::UDF.udf_filter, includingfiltering_ratefallback behavior.About
filter_datafilter_datais the mechanism for passing runtime metadata into the device predicate. It is optional: simple predicates can ignore it or usenullptr.For example, this UDF needs no context:
For metadata filters, callers pass a device pointer to a user-defined context struct:
Then the UDF casts
filter_databack to that type:The pointer and anything it points to must be device-accessible and remain valid for the duration of the search. Future typed wrappers or expression builders could hide this cast, but the internal ABI can stay stable as:
Example
The standalone example builds one CAGRA index and runs three different UDF predicates over the same metadata context.
First, the caller defines a host-side context type. This same layout is repeated in the UDF source string so device code knows how to interpret
filter_data:The UDF source contains the device predicates. Each predicate receives the same
filter_datapointer and casts it tometadata_filter_context:The caller owns the metadata arrays. They must be copied to device memory before search:
Then the caller builds one device-resident context struct whose fields point at those device arrays:
Finally, each
udf_filterselects which device function to link by name. All three filters reuse the same source and the samefilter_datacontext:Example output:
Validation
Focused CAGRA UDF test passes across all CAGRA search algorithms:
Coverage includes:
filtering_rateUDF returns only accepted rowsBroader CAGRA regression set also passed:
Notes
This PR intentionally keeps the v1 API low-level: users provide CUDA source plus an optional
void*device context. Typed wrappers, expression builders, or write-once host/device ergonomics can be layered on later without changing the internal ABI.Filter UDFs are candidate-validity predicates only. They do not control CAGRA graph traversal, distance computation, PQ/VPQ internals, or result selection.
Benchmarks
Performance benchmarking is still TODO for this PR.
The expected regression risk for existing functionality is low because
noneandbitsetfilters still use static JIT-LTO fragments, and the UDF dynamic fragment path is only used forudf_filter. However, the CAGRA JIT kernel payload/signature plumbing changed from bitset-only payloads to a generic sample-filter payload, so we should still validate no-filter and bitset search performance againstmain.Proposed planned benchmark coverage:
mainmainThis should give us both existing-functionality regression coverage and a baseline for the new UDF path.