
Add JIT-LTO based filter UDF support for CAGRA #2132

Open
dantegd wants to merge 12 commits into rapidsai:main from dantegd:fea-udf-cagra-filter

Conversation

Member

@dantegd dantegd commented May 27, 2026

This PR adds a first version of low-level JIT-LTO filter UDF support for CAGRA search in cuVS C++.

Users can provide CUDA device source for a predicate with this ABI:

__device__ bool cuvs_filter_udf(uint32_t query_id,
                                source_index_t source_id,
                                void* filter_data);

The predicate returns true to allow a candidate and false to reject it. filter_data is not result data; it is an optional opaque device-accessible context pointer that lets the predicate read user metadata such as tenant ids, timestamps, language masks, ACL bitmaps, or query-specific thresholds. Results are still written to the normal CAGRA neighbors and distances outputs.

This gives CAGRA a path to support runtime metadata predicates without ahead-of-time template explosion, while keeping existing none and bitset filters on their static JIT-LTO fragment paths.

What Changed

  • Added cuvs::neighbors::filtering::udf_filter.
  • Added FilterType::UDF.
  • Added CAGRA dispatch for udf_filter, including filtering_rate fallback behavior.
  • Added a dynamic JIT-LTO sample-filter fragment generator for UDF source.
  • Reworked CAGRA JIT sample-filter payload plumbing from bitset-only payloads to a generic payload that supports:
    • no filter
    • bitset filter
    • UDF filter context pointer
    • query id offsets for batched/multi-kernel paths
  • Updated single-CTA, multi-CTA, and multi-kernel CAGRA JIT paths to call the linked sample filter uniformly.
  • Added docs for UDF behavior, caveats, and non-goals.
  • Added a standalone C++ example showing three metadata-style UDFs.
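
The query id offsets listed above can be pictured with a small host-side sketch. It assumes that a batched or multi-kernel launch sees batch-local query indices and that the payload carries the offset needed to recover the caller-visible query id. The names `offset_payload_sketch`, `call_filter`, `tenant_ctx`, and `tenant_pred` are hypothetical and are not the PR's actual types.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the generic sample-filter payload: an opaque
// context pointer plus the offset of the current batch in the full query set.
struct offset_payload_sketch {
  void* filter_data;
  std::uint32_t query_id_offset;
};

using predicate_t = bool (*)(std::uint32_t query_id,
                             std::uint32_t source_id,
                             void* filter_data);

// A launch that handles one batch sees batch-local query indices. Adding the
// offset hands the predicate the caller-visible (global) query id.
inline bool call_filter(const offset_payload_sketch& payload,
                        predicate_t pred,
                        std::uint32_t local_query_id,
                        std::uint32_t source_id)
{
  assert(pred != nullptr);
  return pred(payload.query_id_offset + local_query_id, source_id, payload.filter_data);
}

// Example context and predicate: accept a row only if it belongs to the
// tenant of the (global) query.
struct tenant_ctx {
  const std::uint32_t* row_tenant_ids;
  const std::uint32_t* query_tenant_ids;
};

inline bool tenant_pred(std::uint32_t query_id, std::uint32_t source_id, void* filter_data)
{
  const auto* ctx = static_cast<const tenant_ctx*>(filter_data);
  return ctx->row_tenant_ids[source_id] == ctx->query_tenant_ids[query_id];
}
```

Without the offset, a second batch would index per-query metadata from 0 again and silently apply the wrong tenant.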

About filter_data

filter_data is the mechanism for passing runtime metadata into the device predicate. It is optional: simple predicates can ignore it or use nullptr.

For example, this UDF needs no context:

__device__ bool cuvs_filter_udf(uint32_t, source_index_t source_id, void*)
{
  return source_id >= 704;
}

For metadata filters, callers pass a device pointer to a user-defined context struct:

struct tenant_filter_context {
  const uint32_t* row_tenant_ids;
  const uint32_t* query_tenant_ids;
};

Then the UDF casts filter_data back to that type:

__device__ bool tenant_filter(uint32_t query_id,
                              source_index_t source_id,
                              void* filter_data)
{
  auto* ctx = static_cast<const tenant_filter_context*>(filter_data);
  return ctx->row_tenant_ids[source_id] == ctx->query_tenant_ids[query_id];
}

The pointer and anything it points to must be device-accessible and remain valid for the duration of the search. Future typed wrappers or expression builders could hide this cast, but the internal ABI can stay stable as:

bool predicate(uint32_t query_id, source_index_t source_id, void* filter_data);
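
As a rough illustration of what such a typed wrapper might look like, the sketch below adapts a predicate taking `const Context&` to the untyped signature. It is plain host C++ (a device version would need `__device__` qualifiers), it assumes `source_index_t` is `uint32_t`, and `erased_predicate`, `tenant_filter_typed`, and `tenant_filter_abi` are hypothetical names that are not part of this PR.

```cpp
#include <cassert>
#include <cstdint>

using source_index_t = std::uint32_t;  // assumption: matches the CAGRA index type

// Hypothetical typed wrapper: adapts a predicate that takes `const Context&`
// to the untyped (query_id, source_id, void*) signature.
template <typename Context,
          bool (*TypedPredicate)(std::uint32_t, source_index_t, const Context&)>
bool erased_predicate(std::uint32_t query_id, source_index_t source_id, void* filter_data)
{
  assert(filter_data != nullptr);
  return TypedPredicate(query_id, source_id, *static_cast<const Context*>(filter_data));
}

struct tenant_filter_context {
  const std::uint32_t* row_tenant_ids;
  const std::uint32_t* query_tenant_ids;
};

// User-written predicate: no cast from void* is needed here.
inline bool tenant_filter_typed(std::uint32_t query_id,
                                source_index_t source_id,
                                const tenant_filter_context& ctx)
{
  return ctx.row_tenant_ids[source_id] == ctx.query_tenant_ids[query_id];
}

// Function pointer with the untyped ABI shape, generated from the typed predicate.
inline constexpr auto tenant_filter_abi =
  &erased_predicate<tenant_filter_context, &tenant_filter_typed>;
```

The cast lives in one place, so a mismatch between the context type and the predicate becomes a compile error rather than a silent misread of `filter_data`.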

Example

The standalone example builds one CAGRA index and runs three different UDF predicates over the same metadata context.

First, the caller defines a host-side context type. This same layout is repeated in the UDF source string so device code knows how to interpret filter_data:

struct metadata_filter_context {
  const uint32_t* row_tenant_ids;
  const int64_t* row_timestamps;
  const uint32_t* row_language_ids;
  const uint64_t* row_acl_masks;

  const uint32_t* query_tenant_ids;
  const int64_t* query_min_timestamps;
  const uint64_t* query_allowed_language_masks;
  const uint64_t* query_permission_masks;
};

The UDF source contains the device predicates. Each predicate receives the same filter_data pointer and casts it to metadata_filter_context:

std::string source = R"cpp(
struct metadata_filter_context {
  const uint32_t* row_tenant_ids;
  const int64_t* row_timestamps;
  const uint32_t* row_language_ids;
  const uint64_t* row_acl_masks;

  const uint32_t* query_tenant_ids;
  const int64_t* query_min_timestamps;
  const uint64_t* query_allowed_language_masks;
  const uint64_t* query_permission_masks;
};

__device__ bool tenant_filter(uint32_t query_id,
                              source_index_t source_id,
                              void* filter_data)
{
  auto* ctx = static_cast<const metadata_filter_context*>(filter_data);
  return ctx->row_tenant_ids[source_id] == ctx->query_tenant_ids[query_id];
}

__device__ bool timestamp_filter(uint32_t query_id,
                                 source_index_t source_id,
                                 void* filter_data)
{
  auto* ctx = static_cast<const metadata_filter_context*>(filter_data);
  return ctx->row_timestamps[source_id] >= ctx->query_min_timestamps[query_id];
}

__device__ bool language_acl_filter(uint32_t query_id,
                                    source_index_t source_id,
                                    void* filter_data)
{
  auto* ctx = static_cast<const metadata_filter_context*>(filter_data);
  const auto language_bit = uint64_t{1} << ctx->row_language_ids[source_id];
  const bool language_ok =
    (ctx->query_allowed_language_masks[query_id] & language_bit) != 0;
  const bool acl_ok =
    (ctx->row_acl_masks[source_id] & ctx->query_permission_masks[query_id]) != 0;
  return language_ok && acl_ok;
}
)cpp";
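
To make the bitmask logic in `language_acl_filter` concrete, here is a host-side mirror of the same check for a single (query, row) pair. The function name `language_acl_ok` and the explicit range check are additions for illustration; the device predicate above does not perform that check, so language ids are assumed to be below 64.

```cpp
#include <cassert>
#include <cstdint>

// Host-side mirror of the language/ACL check for one (query, row) pair.
inline bool language_acl_ok(std::uint32_t row_language_id,
                            std::uint64_t row_acl_mask,
                            std::uint64_t query_allowed_language_mask,
                            std::uint64_t query_permission_mask)
{
  // Shifting a 64-bit value by 64 or more is undefined, so ids must be < 64.
  assert(row_language_id < 64);
  const std::uint64_t language_bit = std::uint64_t{1} << row_language_id;
  // The row's language must be one of the languages the query allows.
  const bool language_ok = (query_allowed_language_mask & language_bit) != 0;
  // The row and the query must share at least one ACL bit.
  const bool acl_ok = (row_acl_mask & query_permission_mask) != 0;
  return language_ok && acl_ok;
}
```

For example, a row with language id 3 and ACL mask `0b0100` passes for a query that allows languages `0b1010` and holds permissions `0b0110`, because bit 3 is allowed and ACL bit 2 is shared.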

The caller owns the metadata arrays. They must be copied to device memory before search:

auto row_tenant_ids_device = raft::make_device_vector<uint32_t, int64_t>(res, n_rows);
auto row_timestamps_device = raft::make_device_vector<int64_t, int64_t>(res, n_rows);
auto row_language_ids_device = raft::make_device_vector<uint32_t, int64_t>(res, n_rows);
auto row_acl_masks_device = raft::make_device_vector<uint64_t, int64_t>(res, n_rows);

auto query_tenant_ids_device = raft::make_device_vector<uint32_t, int64_t>(res, n_queries);
auto query_min_timestamps_device = raft::make_device_vector<int64_t, int64_t>(res, n_queries);
auto query_allowed_language_masks_device =
  raft::make_device_vector<uint64_t, int64_t>(res, n_queries);
auto query_permission_masks_device =
  raft::make_device_vector<uint64_t, int64_t>(res, n_queries);

// Copy host metadata into these device arrays before launching search.

Then the caller builds one device-resident context struct whose fields point at those device arrays:

auto context_device =
  raft::make_device_vector<metadata_filter_context, int64_t>(res, 1);

metadata_filter_context host_context{
  row_tenant_ids_device.data_handle(),
  row_timestamps_device.data_handle(),
  row_language_ids_device.data_handle(),
  row_acl_masks_device.data_handle(),
  query_tenant_ids_device.data_handle(),
  query_min_timestamps_device.data_handle(),
  query_allowed_language_masks_device.data_handle(),
  query_permission_masks_device.data_handle()
};

// Copy the struct itself to device memory. The pointer to this device struct
// is what we pass as filter_data.
raft::copy(context_device.data_handle(),
           &host_context,
           1,
           raft::resource::get_cuda_stream(res));
raft::resource::sync_stream(res);

Finally, each udf_filter selects which device function to link by name. All three filters reuse the same source and the same filter_data context:

auto tenant_filter = cuvs::neighbors::filtering::udf_filter(
  source,
  context_device.data_handle(),  // filter_data
  0.75f,                         // estimated rejected fraction
  "tenant-filter-v1",            // cache key
  "tenant_filter");              // device function name

auto timestamp_filter = cuvs::neighbors::filtering::udf_filter(
  source,
  context_device.data_handle(),
  0.50f,
  "timestamp-filter-v1",
  "timestamp_filter");

auto language_acl_filter = cuvs::neighbors::filtering::udf_filter(
  source,
  context_device.data_handle(),
  0.875f,
  "language-acl-filter-v1",
  "language_acl_filter");

Example output:

Building CAGRA index
tenant_filter first query neighbors: 3364 3488 260 772 3292 3344 2052 660
timestamp_filter first query neighbors: 4070 3942 3364 3629 3797 3488 3292 3495
language_acl_filter first query neighbors: 3488 3344 2480 2520 2304 2176 1360 400
All CAGRA filter UDF examples produced valid filtered neighbors.

Validation

Focused CAGRA UDF test passes across all CAGRA search algorithms:

NEIGHBORS_ANN_CAGRA_FILTER_UDF_TEST

Coverage includes:

  • accept-all UDF matches no-filter results
  • reject-all UDF returns no valid dataset row ids
  • high filtering_rate UDF returns only accepted rows
  • invalid source throws during compile/search setup
  • repeated same-cache-key searches match
  • UDF threshold predicate matches equivalent bitset filter
  • query-specific tenant metadata works across SINGLE_CTA, MULTI_CTA, and MULTI_KERNEL
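
The "UDF threshold predicate matches equivalent bitset filter" case relies on building a bitset that encodes the same decision as the predicate. A minimal host-side sketch of that construction follows; it assumes 32-bit words with row `i` stored at bit `i % 32` of word `i / 32`, and `bitset_from_predicate` and `test_bit` are hypothetical helpers rather than the test's actual code.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Build a host bitset with one bit per row: bit i is set when row i is allowed.
// Assumed layout: 32-bit words, row i -> word i / 32, bit i % 32.
template <typename Predicate>
std::vector<std::uint32_t> bitset_from_predicate(std::uint32_t n_rows, Predicate allow)
{
  std::vector<std::uint32_t> words((n_rows + 31) / 32, 0u);
  for (std::uint32_t row = 0; row < n_rows; ++row) {
    if (allow(row)) { words[row / 32] |= std::uint32_t{1} << (row % 32); }
  }
  return words;
}

// Read back the decision for one row.
inline bool test_bit(const std::vector<std::uint32_t>& words, std::uint32_t row)
{
  assert(row / 32 < words.size());
  return ((words[row / 32] >> (row % 32)) & 1u) != 0;
}
```

Searching once with the bitset and once with the UDF should then yield the same neighbors, which is what the parity test checks.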

Broader CAGRA regression set also passed:

NEIGHBORS_ANN_CAGRA_(FILTER_UDF_TEST|FLOAT_UINT32|INT8_UINT32|UINT8_UINT32|HALF_UINT32|TEST_BUGS)

Notes

This PR intentionally keeps the v1 API low-level: users provide CUDA source plus an optional void* device context. Typed wrappers, expression builders, or write-once host/device ergonomics can be layered on later without changing the internal ABI.

Filter UDFs are candidate-validity predicates only. They do not control CAGRA graph traversal, distance computation, PQ/VPQ internals, or result selection.

Benchmarks

Performance benchmarking is still TODO for this PR.

The expected regression risk for existing functionality is low because none and bitset filters still use static JIT-LTO fragments, and the UDF dynamic fragment path is only used for udf_filter. However, the CAGRA JIT kernel payload/signature plumbing changed from bitset-only payloads to a generic sample-filter payload, so we should still validate no-filter and bitset search performance against main.

Planned benchmark coverage:

  • no-filter CAGRA search vs main
  • bitset-filtered CAGRA search vs main
  • equivalent bitset predicate vs UDF predicate
  • first-call UDF compile/link latency
  • warm-cache repeated UDF search latency
  • representative CAGRA algorithms/configs, especially SINGLE_CTA, MULTI_CTA, and MULTI_KERNEL where applicable

This should give us both existing-functionality regression coverage and a baseline for the new UDF path.


copy-pr-bot Bot commented May 27, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

dantegd added the "feature request" and "non-breaking" labels May 27, 2026
cjnolet assigned cjnolet and dantegd, then unassigned cjnolet May 27, 2026
cjnolet moved this to In Progress in Unstructured Data Processing May 27, 2026
Comment thread cpp/include/cuvs/neighbors/common.hpp Outdated
Comment on lines +641 to +644
/** Optional stable cache key for equivalent generated source. */
std::string cache_key;
/** Device function name to call from the generated CAGRA sample filter. */
std::string function_name = "cuvs_filter_udf";
Member

Why expose these two? We can just hide them as implementation details. As you noted previously in IVF Flat, the cache key can just be the string itself.

Member Author

Agree on cache_key; it is an implementation detail and can be derived internally from the source/generated code.

For function_name, the intent was to support one source string with several predicates over the same metadata context, e.g. tenant_filter, timestamp_filter, and acl_filter, and let the user select which one to link without regenerating wrapper source each time. What do you think?

Member

Oh that is pretty smart, so you just pay compilation costs once.

Member Author

cache_key addressed in 292f7e8
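
One way to picture an internally derived cache key is the host-side sketch below. It assumes the key is simply the selected function name joined with the source text, and that a compile callback produces some fragment object. `fragment_cache_sketch` and `compiled_fragment` are hypothetical; the actual derivation in 292f7e8 is not shown in this thread.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical result of the JIT step.
struct compiled_fragment {
  std::string entry_point;
};

// Hypothetical cache keyed by (function name, source text).
class fragment_cache_sketch {
 public:
  using compile_fn =
    std::function<compiled_fragment(const std::string& source, const std::string& name)>;

  // Returns the cached fragment, invoking `compile` only on the first request.
  const compiled_fragment& get_or_compile(const std::string& source,
                                          const std::string& function_name,
                                          const compile_fn& compile)
  {
    // '\0' cannot appear in a C++ identifier, so name and source cannot collide.
    std::string key = function_name;
    key.push_back('\0');
    key += source;
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      ++compile_count_;
      it = cache_.emplace(std::move(key), compile(source, function_name)).first;
    }
    return it->second;
  }

  std::size_t compile_count() const { return compile_count_; }

 private:
  std::unordered_map<std::string, compiled_fragment> cache_;
  std::size_t compile_count_ = 0;
};
```

With this scheme the caller never supplies a key, and repeated searches with the same source and function name hit the cache.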

/// Host/device payload for linked CAGRA sample filters plus query offset for wrapped filters.
template <typename SourceIndexT>
struct cagra_sample_filter {
cagra_bitset<SourceIndexT> bitset{};
Member

Can we hide this behind the opaque filter_data? bitset is an implementation detail of the bitset_filter

Member Author

refactored this into a generic cagra_filter_payload.cuh and renamed the embedded storage to filter_data_storage in 296d44c

Comment on lines +48 to +50
out.bitset.bitset_ptr = const_cast<std::uint32_t*>(bitset_view.data());
out.bitset.bitset_len = static_cast<SourceIndexT>(bitset_view.size());
out.bitset.original_nbits = static_cast<SourceIndexT>(bitset_view.get_original_nbits());
Member

So going by my above comment, you would assign out.filter_data = cagra_bitset<SourceIndexT>{bitset_ptr, bitset_len, original_nbits} and then reinterpret this in the fragment
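
A minimal sketch of that idea, under the assumption that the payload holds a small fixed-size opaque buffer and only the bitset fragment knows the concrete type. All names here (`bitset_data_sketch`, `opaque_payload_sketch`, `store_filter_data`, `load_bitset`) and the 32-byte buffer size are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <type_traits>

// Hypothetical bitset descriptor, mirroring the three fields assigned in the diff.
struct bitset_data_sketch {
  std::uint32_t* bitset_ptr;
  std::uint32_t bitset_len;
  std::uint32_t original_nbits;
};

// Hypothetical payload with opaque inline storage instead of a named bitset member.
struct opaque_payload_sketch {
  static constexpr std::size_t storage_bytes = 32;
  alignas(8) unsigned char storage[storage_bytes];
};

// Host side: copy any small trivially copyable descriptor into the storage.
template <typename T>
void store_filter_data(opaque_payload_sketch& payload, const T& value)
{
  static_assert(std::is_trivially_copyable_v<T>, "descriptor must be trivially copyable");
  static_assert(sizeof(T) <= opaque_payload_sketch::storage_bytes, "descriptor too large");
  std::memcpy(payload.storage, &value, sizeof(T));
}

// Fragment side: only the bitset fragment reinterprets the bytes as a bitset.
inline bitset_data_sketch load_bitset(const opaque_payload_sketch& payload)
{
  bitset_data_sketch out;
  std::memcpy(&out, payload.storage, sizeof(out));
  return out;
}
```

Reading from the inline storage, rather than through a pointer into the payload, keeps the data valid when the payload is copied by value into a kernel launch.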

Comment thread cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_bitset.cuh Outdated
Comment on lines +89 to +96
template <typename SourceIndexT>
__device__ __forceinline__ void* get_cagra_sample_filter_data(
cagra_sample_filter<SourceIndexT>& payload)
{
if (payload.filter_kind == cagra_filter_kind::udf) { return payload.filter_data; }
if (payload.filter_kind == cagra_filter_kind::bitset) { return &payload.bitset; }
return nullptr;
}
Member

If you make bitset opaque then you don't need this function

Comment thread cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh Outdated
Comment thread cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh
dantegd marked this pull request as ready for review June 4, 2026 20:52
dantegd requested review from a team as code owners June 4, 2026 20:52
dantegd changed the title from "Add JIT-LTO filter UDF support for CAGRA" to "Add JIT-LTO based filter UDF support for CAGRA" Jun 4, 2026

coderabbitai Bot commented Jun 4, 2026


📝 Walkthrough

Summary by CodeRabbit

  • New Features

    • Added JIT-compiled, device-side user-defined CUDA filter UDFs for CAGRA search across single-/multi-CTA and multi-kernel paths, with runtime payloads, optional per-query context, and integrated launcher support.
  • Documentation

    • New docs and examples showing how to author, JIT, wire, and manage UDF filters, including filtering_rate fallback behavior and metadata-filter patterns.
  • Tests

    • New test suite validating UDF filter correctness, parity with bitset filters, and error handling.

Walkthrough

Adds UDF-based device filters to CAGRA: new FilterType::UDF and udf_filter API, cagra_sample_filter payload and device-owner cache, NVRTC UDF fragment factory, planner/launcher plumbing to accept fragments, kernel ABI/type migration to use sample-filter payloads, persistent-runner hashing updates, tests, example, and docs.

Changes

CAGRA UDF Filter Implementation

| Layer | File(s) | Summary |
|---|---|---|
| Public API: UDF filter type and tag | cpp/include/cuvs/detail/jit_lto/common_fragments.hpp, cpp/include/cuvs/neighbors/common.hpp, cpp/include/cuvs/neighbors/cagra.hpp | Adds tag_filter_udf, FilterType::UDF, new udf_filter struct, and documents filtering_rate fallback behavior. |
| Host/device payload and helpers | cpp/src/neighbors/detail/cagra/cagra_filter_payload.hpp, cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh | Introduces payload hashing, cagra_device_payload_owner cache, traits to detect bitset vs UDF filters, and helpers to fill/extract cagra_sample_filter for device dispatch. |
| UDF fragment generation and factories | cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh, cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp | Builds CUDA source for user UDFs, validates fields, compiles to UDFFatbinFragment via NVRTC, and threads optional fragments into JIT launcher build paths. |
| Planner registration and validation | cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp | add_sample_filter_device_function now accepts an optional UDFFatbinFragment and validates ownership based on the JIT tag before registering fragments. |
| Kernel typedefs and ABI updates | cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp, *.cu.in, *_jit.cuh | Replaces cagra_bitset with cagra_sample_filter across kernel typedefs, exported kernels, JIT wrappers, and in-kernel sample_filter call sites; supports query_id_offset indexing. |
| Search dispatch and persistent-runner integration | cpp/src/neighbors/cagra.cuh, cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh, cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh, cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh | Detects udf_filter in cagra::search, applies filtering_rate fallback, compiles or passes UDF fragments into launcher factories, updates kernel instantiations and dispatch to pass cagra_sample_filter payloads, and includes sample-filter-derived hash into persistent-runner keys. |
| Kernel instantiations | cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in, cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cu.in | Adds udf_filter_t instantiations alongside existing none/bitset specializations. |

Tests and Documentation

| Layer | File(s) | Summary |
|---|---|---|
| CAGRA UDF filter tests | cpp/tests/neighbors/ann_cagra/test_filter_udf.cu, cpp/tests/CMakeLists.txt | Adds gtest suite exercising accept/reject/threshold/determinism/invalid-source and tenant-context scenarios across single-CTA, multi-CTA, and multi-kernel modes; adds CMake test entry. |
| Example: metadata-driven UDF | examples/cpp/src/cagra_filter_udf_example.cu, examples/cpp/CMakeLists.txt | Example demonstrating device metadata context, UDF CUDA source, index build, and filtered search validation. |
| Docs: CAGRA filter UDF usage | fern/pages/neighbors/cagra.md, fern/pages/working_with_ann_indexes.md | Documents device-predicate signature, filter_data semantics, filtering_rate fallback rules, usage constraints, and links to examples. |

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Possibly related PRs:

    • rapidsai/cuvs#1807: earlier work on sample-filter bitset dispatch and payload handling related to this UDF extension.
  • Suggested labels: C++, doc

  • Suggested reviewers:

    • KyleFromNVIDIA
    • divyegala
    • tfeher
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 1.92% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title 'Add JIT-LTO based filter UDF support for CAGRA' directly and specifically summarizes the main change: adding user-defined filter (UDF) support for CAGRA with JIT-LTO compilation. |
| Description check | ✅ Passed | The description is comprehensive and directly related to the changeset, explaining the feature addition, API details, implementation approach, examples, validation, and performance considerations. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |



@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (4)
cpp/include/cuvs/neighbors/common.hpp (1)

634-658: 💤 Low value

Consider validating required UDF fields in the constructor.

The udf_filter constructor accepts source and function_name but does not validate that they are non-empty. While validation occurs later during JIT compilation (in make_cagra_sample_filter_udf_fragment), failing early in the constructor would provide clearer error messages and better user experience.

💡 Suggested defensive check
   explicit udf_filter(std::string source,
                       void* filter_data = nullptr,
                       float filtering_rate = -1.0f,
                       std::string function_name = "cuvs_filter_udf")
     : source(std::move(source)),
       filter_data(filter_data),
       filtering_rate(filtering_rate),
       function_name(std::move(function_name))
   {
+    // Optional: validate required fields early
+    // RAFT_EXPECTS(!this->source.empty(), "UDF source must not be empty");
+    // RAFT_EXPECTS(!this->function_name.empty(), "UDF function name must not be empty");
   }
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/include/cuvs/neighbors/common.hpp` around lines 634 - 658, The udf_filter
constructor should validate that required strings are non-empty; update the
explicit udf_filter(std::string source, void* filter_data, float filtering_rate,
std::string function_name) to check source and function_name and throw a clear
exception (e.g., std::invalid_argument) or call an assert if either is empty, so
invalid UDFs fail fast (this is the same validation later done in
make_cagra_sample_filter_udf_fragment but should be performed at construction
time in udf_filter).
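
A minimal sketch of the fail-fast variant described in this nitpick, using `std::invalid_argument` as the prompt suggests. `udf_filter_sketch` is a hypothetical stand-in whose member layout follows the constructor shown in the suggestion; it is not the real `udf_filter`, and the project may prefer `RAFT_EXPECTS` instead.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical stand-in for udf_filter showing construction-time validation.
struct udf_filter_sketch {
  std::string source;
  void* filter_data;
  float filtering_rate;
  std::string function_name;

  explicit udf_filter_sketch(std::string source_,
                             void* filter_data_        = nullptr,
                             float filtering_rate_     = -1.0f,
                             std::string function_name_ = "cuvs_filter_udf")
    : source(std::move(source_)),
      filter_data(filter_data_),
      filtering_rate(filtering_rate_),
      function_name(std::move(function_name_))
  {
    // Reject obviously unusable input before any JIT work is attempted.
    if (source.empty()) { throw std::invalid_argument("UDF source must not be empty"); }
    if (function_name.empty()) {
      throw std::invalid_argument("UDF function name must not be empty");
    }
  }
};

// Helper for callers that want a yes/no answer instead of an exception.
inline bool is_constructible_udf(const std::string& source, const std::string& function_name)
{
  try {
    udf_filter_sketch filter(source, nullptr, -1.0f, function_name);
    (void)filter;
    return true;
  } catch (const std::invalid_argument&) {
    return false;
  }
}
```

The trade-off is that a constructor that throws changes the failure point from search time to filter construction, which callers would need to know about.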
cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh (1)

34-42: 💤 Low value

Consider namespace isolation for generated type aliases.

Lines 34-40 inject type aliases (int8_t, uint32_t, etc.) into the global namespace of the generated CUDA source. While this is generally safe since NVRTC compiles each fragment in isolation, explicitly scoping these types or documenting the naming convention would prevent potential confusion if users' UDF code expects these types from <cstdint>.

The current approach works because:

  • Generated source is compiled separately via NVRTC
  • User UDF code will see these definitions

However, for robustness, consider either:

  1. Adding a comment explaining why global aliases are safe here
  2. Or including actual headers (#include <cstdint>) in the generated source
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh` around
lines 34 - 42, The generated CUDA fragment in sample_filter_udf.cuh currently
injects fundamental type aliases via the oss R"(...)" block (e.g., int8_t,
uint32_t, source_index_t) into the global namespace; modify the generator to
either (A) wrap these aliases in a dedicated namespace (e.g., namespace
jit_types { using int8_t = signed char; ... using source_index_t =
<source_index_type>; }) and update any emitted code to reference
jit_types::source_index_t, or (B) include the proper header by emitting `#include`
<cstdint> at the top of the generated source and only emit the project-specific
alias source_index_t (still qualified or namespaced). Update uses of
source_index_t in the generated UDF to reference the chosen namespace if you
choose (A); keep changes localized to the oss string assembly in
sample_filter_udf.cuh.
cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh (1)

54-59: ⚡ Quick win

Clarify const_cast usage for CAGRA bitset payload

bitset_filter_data_t::bitset_ptr is defined as std::uint32_t* (non-const), so the const_cast<std::uint32_t*>(bitset_view.data()) in cagra_filter_payload.cuh is required to assign filter.view().data() into that payload type; the device code then treats the bitset as read-only (builds a const raft::core::bitset_view and only calls test).

Optional: change bitset_filter_data_t::bitset_ptr (and corresponding kernel parameter types) to const std::uint32_t* to remove the cast, or document explicitly that the bitset is immutable during search.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh`
around lines 54 - 59, The code uses const_cast to assign filter.view().data()
into cagra_filter_data_storage because bitset_filter_data_t::bitset_ptr is a
non-const std::uint32_t*; replace this unsafe cast by making the payload and
kernel types explicitly const: change bitset_filter_data_t::bitset_ptr (and any
kernel parameter types and cagra_filter_data_storage template instantiations) to
const std::uint32_t* so the bitset is treated as immutable, and update any
affected functions (e.g., places constructing cagra_filter_data_storage in
cagra_filter_payload.cuh and kernel signatures that consume it) to accept the
const pointer, or alternatively add a clear comment next to
bitset_filter_data_t::bitset_ptr and the const raft::core::bitset_view usage
documenting that the bitset is immutable during search if you prefer to keep the
current type.
cpp/tests/neighbors/ann_cagra/test_filter_udf.cu (1)

106-170: ⚡ Quick win

Consider adding edge case tests for robustness.

The test fixture provides good coverage of the UDF filtering mechanism, but per coding guidelines, numerical correctness tests should validate edge cases. Consider adding tests for:

  • Empty dataset or query set (n_rows=0 or n_queries=0)
  • Single-element dataset (n_rows=1)
  • Requesting k neighbors when fewer than k rows pass the filter

These edge cases would help ensure the UDF filtering path handles boundary conditions consistently with other CAGRA filter types.

As per coding guidelines: Numerical correctness tests must validate edge cases: empty inputs, single elements, zero-norm vectors, identical points, and extreme values.
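
For the "fewer than k rows pass the filter" case, a brute-force host reference makes the expected shape of the output explicit. The sketch below assumes unfilled slots are reported with a sentinel id equal to the maximum `uint32_t` value; that sentinel and the names `invalid_id` and `filtered_top_k` are assumptions, not taken from the test suite.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>

// Assumed marker for "no valid neighbor in this slot".
constexpr std::uint32_t invalid_id = std::numeric_limits<std::uint32_t>::max();

// Brute-force reference for one query: keep only rows the predicate accepts,
// sort them by distance, and pad the result with invalid_id up to k entries.
template <typename Predicate>
std::vector<std::uint32_t> filtered_top_k(const std::vector<float>& distances,
                                          std::uint32_t query_id,
                                          std::size_t k,
                                          Predicate allow)
{
  std::vector<std::pair<float, std::uint32_t>> kept;
  for (std::uint32_t row = 0; row < distances.size(); ++row) {
    if (allow(query_id, row)) { kept.emplace_back(distances[row], row); }
  }
  std::sort(kept.begin(), kept.end());

  std::vector<std::uint32_t> out(k, invalid_id);
  const std::size_t n_valid = std::min(k, kept.size());
  for (std::size_t i = 0; i < n_valid; ++i) { out[i] = kept[i].second; }
  return out;
}
```

Because CAGRA is approximate, a test would more likely compare the set of valid ids against this reference than require an exact ordering match.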

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/tests/neighbors/ann_cagra/test_filter_udf.cu` around lines 106 - 170, Add
edge-case unit tests using the existing CagraUdfFilterTest fixture: create new
TEST_P cases that override n_rows and n_queries (set to 0 and 1 as needed) and
invoke the search(...) helper with UDF filters to validate behavior for empty
dataset, empty queries, single-row dataset, and requesting k larger than
available passing rows; ensure you assert on cagra_search_result.neighbors and
.distances for expected shapes/values and that no crashes or undefined-copy
occur (use the existing search function, the class CagraUdfFilterTest, and the
cagra_search_result return value to drive the checks).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@cpp/src/neighbors/cagra.cuh`:
- Around line 392-399: The code currently uses std::min/std::max to clamp
sample_filter.filtering_rate into params_copy.filtering_rate but does not guard
against NaN/infinite values, so a NaN in sample_filter.filtering_rate will
propagate; update the block (around params.filtering_rate,
params_copy.filtering_rate, and sample_filter.filtering_rate) to first validate
finiteness (e.g., std::isfinite or !std::isnan) and if the value is not finite
either reject it or set params_copy.filtering_rate to a safe default (0.0f)
before applying std::min/std::max, and ensure any rejection path emits an
appropriate error/return per existing error handling policy.

In `@cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh`:
- Around line 21-27: Update the public Doxygen for the udf_filter API to note
that CAGRA-generated JIT fragments use source_index_t == uint32_t (i.e.,
SourceIndexT is currently restricted to uint32_t), so UDF filters must expect a
uint32_t source index; reference the internal symbol
cagra_udf_source_index_type_name and the typedef/source_index_t in the comment
so users know this constraint up front.
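
The finiteness guard requested for `filtering_rate` can be sketched as below. With `std::max(rate, 0.0f)` a NaN in the first argument is returned unchanged, which is why the check has to come before the clamp. The upper bound of 0.999, the 0.0 fallback, and the name `sanitize_filtering_rate` are assumptions; how the real code treats a negative "unset" rate is not shown here.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Sketch of a finiteness guard before clamping. The [0, max_rate] bounds and
// the 0.0f fallback are assumptions, not values taken from the PR.
inline float sanitize_filtering_rate(float rate, float max_rate = 0.999f)
{
  // NaN compares false against everything, so min/max alone would let it through.
  if (!std::isfinite(rate)) { return 0.0f; }
  return std::min(std::max(rate, 0.0f), max_rate);
}
```

Rejecting the value with an error instead of falling back to 0.0 would be the stricter alternative mentioned in the comment.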

---

Nitpick comments:
In `@cpp/include/cuvs/neighbors/common.hpp`:
- Around line 634-658: The udf_filter constructor should validate that required
strings are non-empty; update the explicit udf_filter(std::string source, void*
filter_data, float filtering_rate, std::string function_name) to check source
and function_name and throw a clear exception (e.g., std::invalid_argument) or
call an assert if either is empty, so invalid UDFs fail fast (this is the same
validation later done in make_cagra_sample_filter_udf_fragment but should be
performed at construction time in udf_filter).

In `@cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh`:
- Around line 54-59: The code uses const_cast to assign filter.view().data()
into cagra_filter_data_storage because bitset_filter_data_t::bitset_ptr is a
non-const std::uint32_t*; replace this unsafe cast by making the payload and
kernel types explicitly const: change bitset_filter_data_t::bitset_ptr (and any
kernel parameter types and cagra_filter_data_storage template instantiations) to
const std::uint32_t* so the bitset is treated as immutable, and update any
affected functions (e.g., places constructing cagra_filter_data_storage in
cagra_filter_payload.cuh and kernel signatures that consume it) to accept the
const pointer, or alternatively add a clear comment next to
bitset_filter_data_t::bitset_ptr and the const raft::core::bitset_view usage
documenting that the bitset is immutable during search if you prefer to keep the
current type.

In `@cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh`:
- Around line 34-42: The generated CUDA fragment in sample_filter_udf.cuh
currently injects fundamental type aliases via the oss R"(...)" block (e.g.,
int8_t, uint32_t, source_index_t) into the global namespace; modify the
generator to either (A) wrap these aliases in a dedicated namespace (e.g.,
namespace jit_types { using int8_t = signed char; ... using source_index_t =
<source_index_type>; }) and update any emitted code to reference
jit_types::source_index_t, or (B) include the proper header by emitting `#include`
<cstdint> at the top of the generated source and only emit the project-specific
alias source_index_t (still qualified or namespaced). Update uses of
source_index_t in the generated UDF to reference the chosen namespace if you
choose (A); keep changes localized to the oss string assembly in
sample_filter_udf.cuh.
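
Option (A) could look roughly like the following host-side generator. This is a sketch under assumptions: `make_udf_type_prelude` is a hypothetical helper, and the exact set of aliases emitted by the real generator is not shown in this thread.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical generator: emits the type prelude of a UDF fragment inside a
// dedicated namespace instead of the global one.
inline std::string make_udf_type_prelude(const std::string& source_index_type)
{
  std::ostringstream oss;
  oss << "namespace jit_types {\n"
      << "using int8_t   = signed char;\n"
      << "using uint32_t = unsigned int;\n"
      // Resolves against the aliases above because it sits in the same namespace.
      << "using source_index_t = " << source_index_type << ";\n"
      << "}  // namespace jit_types\n";
  return oss.str();
}
```

One consequence worth checking before adopting this: UDF source written against the unqualified ABI (`uint32_t query_id, source_index_t source_id`) would need to qualify those names or rely on a using-declaration emitted by the generator.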

In `@cpp/tests/neighbors/ann_cagra/test_filter_udf.cu`:
- Around line 106-170: Add edge-case unit tests using the existing
CagraUdfFilterTest fixture: create new TEST_P cases that override n_rows and
n_queries (set to 0 and 1 as needed) and invoke the search(...) helper with UDF
filters to validate behavior for empty dataset, empty queries, single-row
dataset, and requesting k larger than available passing rows; ensure you assert
on cagra_search_result.neighbors and .distances for expected shapes/values and
that no crashes or undefined-copy occur (use the existing search function, the
class CagraUdfFilterTest, and the cagra_search_result return value to drive the
checks).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 4d44a786-cdda-4dd1-9a2c-9ef7fcde1b54

📥 Commits

Reviewing files that changed from the base of the PR and between 3035cd0 and 51716a6.

📒 Files selected for processing (31)
  • cpp/include/cuvs/detail/jit_lto/common_fragments.hpp
  • cpp/include/cuvs/neighbors/cagra.hpp
  • cpp/include/cuvs/neighbors/common.hpp
  • cpp/src/neighbors/cagra.cuh
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_bitset.cuh
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/kernel_def.hpp
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_kernel.cu.in
  • cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cu.in
  • cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh
  • cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh
  • cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh
  • cpp/src/neighbors/detail/cagra/search_single_cta_inst.cu.in
  • cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh
  • cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp
  • cpp/tests/CMakeLists.txt
  • cpp/tests/neighbors/ann_cagra/test_filter_udf.cu
  • examples/cpp/CMakeLists.txt
  • examples/cpp/src/cagra_filter_udf_example.cu
  • fern/pages/neighbors/cagra.md
  • fern/pages/working_with_ann_indexes.md
💤 Files with no reviewable changes (1)
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_bitset.cuh

Comment thread cpp/src/neighbors/cagra.cuh Outdated
Comment on lines +392 to +399
  if (params.filtering_rate < 0.0) {
    const float min_filtering_rate = 0.0f;
    const float max_filtering_rate = 0.999f;
    params_copy.filtering_rate = sample_filter.filtering_rate < 0.0f
                                   ? 0.0f
                                   : std::min(std::max(sample_filter.filtering_rate,
                                                       min_filtering_rate),
                                              max_filtering_rate);

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Reject NaN filtering_rate before clamping.

std::min/std::max do not sanitize NaN, so udf_filter.filtering_rate = NaN will propagate a NaN into params_copy.filtering_rate and then into downstream planner/kernel selection. Please validate finiteness here and either reject it or fall back to 0.0f.

Suggested fix
+#include <cmath>
+
   if (params.filtering_rate < 0.0) {
     const float min_filtering_rate = 0.0f;
     const float max_filtering_rate = 0.999f;
+    RAFT_EXPECTS(std::isfinite(sample_filter.filtering_rate) ||
+                   sample_filter.filtering_rate < 0.0f,
+                 "UDF filtering_rate must be finite or negative");
     params_copy.filtering_rate = sample_filter.filtering_rate < 0.0f
                                    ? 0.0f
-                                   : std::min(std::max(sample_filter.filtering_rate,
-                                                       min_filtering_rate),
-                                              max_filtering_rate);
+                                   : std::clamp(sample_filter.filtering_rate,
+                                                min_filtering_rate,
+                                                max_filtering_rate);
   }

As per coding guidelines, "Input validation must check for negative or invalid dimensions, null pointers, and invalid parameter combinations before GPU operations."
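
The behavior of the suggested fix can be checked in isolation with a small host-only helper. This is a sketch, not cuVS code: `resolve_udf_filtering_rate` is a hypothetical name, and it throws `std::invalid_argument` in place of `RAFT_EXPECTS` so that it compiles without RAFT.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <stdexcept>

// Hypothetical helper mirroring the suggested fix.
inline float resolve_udf_filtering_rate(float rate)
{
  constexpr float min_filtering_rate = 0.0f;
  constexpr float max_filtering_rate = 0.999f;

  // NaN and +inf fail both tests; any negative value (including -inf) passes.
  if (!(std::isfinite(rate) || rate < 0.0f)) {
    throw std::invalid_argument("UDF filtering_rate must be finite or negative");
  }
  // Negative values fall back to 0, as in the original code.
  if (rate < 0.0f) { return 0.0f; }
  // Safe to clamp: rate is finite and non-negative here.
  return std::clamp(rate, min_filtering_rate, max_filtering_rate);
}
```

Validating before clamping matters because every ordered comparison with NaN is false, so both `std::min(std::max(x, lo), hi)` and `std::clamp(x, lo, hi)` hand a NaN straight back.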

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/src/neighbors/cagra.cuh` around lines 392 - 399, The code currently uses
std::min/std::max to clamp sample_filter.filtering_rate into
params_copy.filtering_rate but does not guard against NaN/infinite values, so a
NaN in sample_filter.filtering_rate will propagate; update the block (around
params.filtering_rate, params_copy.filtering_rate, and
sample_filter.filtering_rate) to first validate finiteness (e.g., std::isfinite
or !std::isnan) and if the value is not finite either reject it or set
params_copy.filtering_rate to a safe default (0.0f) before applying
std::min/std::max, and ensure any rejection path emits an appropriate
error/return per existing error handling policy.

Comment on lines +21 to +27
template <typename SourceIndexT>
inline constexpr const char* cagra_udf_source_index_type_name()
{
  static_assert(std::is_same_v<SourceIndexT, std::uint32_t>,
                "CAGRA filter UDFs currently support SourceIndexT = uint32_t only");
  return "uint32_t";
}

🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win

Document SourceIndexT=uint32_t restriction in public API.

The static_assert restricts UDF filters to SourceIndexT = uint32_t, but this limitation is not mentioned in the public udf_filter documentation. Users won't discover the constraint until the static_assert fires.

Consider adding a note in the udf_filter Doxygen comment (cpp/include/cuvs/neighbors/common.hpp lines 620-633) stating that CAGRA currently provides source_index_t as uint32_t in the generated JIT fragment, matching the documentation on lines 631-632.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh` around
lines 21 - 27, Update the public Doxygen for the udf_filter API to note that
CAGRA-generated JIT fragments use source_index_t == uint32_t (i.e., SourceIndexT
is currently restricted to uint32_t), so UDF filters must expect a uint32_t
source index; reference the internal symbol cagra_udf_source_index_type_name and
the typedef/source_index_t in the comment so users know this constraint up
front.

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh`:
- Around line 59-115: The current cagra_device_payload_owner uses a single
shared state (struct state) with one device_payload and stream per hashed
payload (keyed by cagra_payload_hash), which is not stream-safe and leaks device
allocations in the static cache; change the design so the device copy is managed
per stream (or synchronized across streams) and make the cache evictable.
Concretely: replace the single device_payload/stream in state with a per-stream
mapping (e.g., map<cudaStream_t, device_entry> or use raft::stream_syncer to
serialize cross-stream access) so dev_ptr(cudaStream_t) either returns a device
copy created on that stream or waits on the creator’s event before reuse; also
ensure the static cache supports eviction (e.g., weak_ptrs or explicit erase
when state has no live owners) to avoid persistent allocations. Update
cagra_device_payload_owner::cagra_device_payload_owner, state::dev_ptr, and the
static cache usage to implement per-stream entries and proper stream
synchronization (or stream_syncer) and add eviction logic.
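
The per-stream keying and `weak_ptr`-based eviction suggested above can be modelled on the host without any CUDA calls. In this sketch every name is hypothetical: `stream_handle` stands in for `cudaStream_t`, and `entry` stands in for a device allocation together with its ready event.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <mutex>
#include <utility>

using stream_handle = int;  // stand-in for cudaStream_t in this host-only model

struct entry {
  std::size_t payload_hash;
  stream_handle stream;
};

class per_stream_cache {
 public:
  // Returns the live entry for (hash, stream), creating one if none is alive.
  std::shared_ptr<entry> acquire(std::size_t payload_hash, stream_handle stream)
  {
    std::lock_guard<std::mutex> lock(mutex_);
    evict_expired();
    const auto key = std::make_pair(payload_hash, stream);
    auto it        = cache_.find(key);
    if (it != cache_.end()) {
      if (auto live = it->second.lock()) { return live; }
    }
    auto created = std::make_shared<entry>(entry{payload_hash, stream});
    cache_[key]  = created;  // the cache holds only a weak reference
    return created;
  }

  // Number of entries that still have at least one owner.
  std::size_t size()
  {
    std::lock_guard<std::mutex> lock(mutex_);
    evict_expired();
    return cache_.size();
  }

 private:
  // Drops map slots whose owners have all gone away.
  void evict_expired()
  {
    for (auto it = cache_.begin(); it != cache_.end();) {
      if (it->second.expired()) {
        it = cache_.erase(it);
      } else {
        ++it;
      }
    }
  }

  std::mutex mutex_;
  std::map<std::pair<std::size_t, stream_handle>, std::weak_ptr<entry>> cache_;
};
```

In a real implementation the `entry` destructor would presumably release the device allocation and destroy the event; that part is deliberately left out here, as is any cross-stream event wait.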

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 1b9c2fd6-80f8-4ab2-beda-c8ec4c98709d

📥 Commits

Reviewing files that changed from the base of the PR and between 51716a6 and 3f021bb.

📒 Files selected for processing (4)
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh
  • cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh
  • cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh
  • cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh
🚧 Files skipped from review as they are similar to previous changes (2)
  • cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh
  • cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh

Comment thread cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh Outdated
#include "../../../sample_filter.cuh" // public filter types
#include "../../sample_filter_data.cuh"

#if !defined(__CUDACC_RTC__)
Member

Are you including this header in runtime compilation? How are you doing so?

Member Author

Yes, indirectly. The generated JIT kernel sources include kernel_def.hpp / the JIT kernel headers, and those need the cagra_sample_filter parameter type in the kernel signatures. That’s why the small jit_lto_kernels/cagra_filter_payload.cuh header remains in the runtime-compiled/JIT side.

I moved the host-only ownership/cache/extraction logic out of this header into cagra_filter_payload.hpp, so the runtime-compiled side only sees the minimal kernel ABI struct.

}

private:
mutable std::shared_ptr<state> state_;
Member

Just wanted to point out that by using a shared_ptr that is initialized on construction, it becomes a requirement that state be initialized before threads branch off. Would that be the case if the user were handling threading and calling search? If not, would it make sense to make this static instead?

template <typename SourceIndexT>
struct cagra_sample_filter_payload {
  cagra_sample_filter<SourceIndexT> payload{};
  cagra_device_payload_owner<cagra_filter_data_storage<SourceIndexT>> storage_owner{};
Member

@divyegala divyegala Jun 5, 2026

cagra_filter_data_storage is still of type bitset. So we will have to add an std::variant here to support pre-compiled filter types? In which case, why even have the owner, we could just go back to storing it by value for pre-compiled types.
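
As a rough illustration of the by-value alternative raised here, a `std::variant` over the pre-compiled payload kinds might look like the sketch below. All type and function names are hypothetical and the real cuVS payload structs differ; the point is only how a single by-value member can carry any of the filter kinds.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>
#include <variant>

// Hypothetical payload types for illustration only.
struct no_filter_data {};
struct bitset_filter_data {
  const std::uint32_t* bitset_ptr;
  std::uint32_t bitset_len;
};
struct udf_filter_data {
  void* context;
};

using filter_data_variant =
  std::variant<no_filter_data, bitset_filter_data, udf_filter_data>;

// Host-side dispatch: returns the pointer a kernel would receive, or nullptr
// when there is no filter.
inline const void* kernel_argument(const filter_data_variant& data)
{
  return std::visit(
    [](const auto& d) -> const void* {
      using T = std::decay_t<decltype(d)>;
      if constexpr (std::is_same_v<T, bitset_filter_data>) {
        return d.bitset_ptr;
      } else if constexpr (std::is_same_v<T, udf_filter_data>) {
        return d.context;
      } else {
        return nullptr;
      }
    },
    data);
}
```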

/// Host: find UDF compile/link metadata only. Query offsets stay in the runtime payload produced
/// by @ref extract_cagra_sample_filter and are applied before calling the linked sample_filter.
template <typename SampleFilterT>
const ::cuvs::neighbors::filtering::udf_filter* get_cagra_udf_filter(
Member

This function is a bit ugly. Is it possible to enforce that any and all CAGRA filters are wrapped by the outer SampleFilter type?

{
const auto bf = extract_cagra_sample_filter<SourceIndexT>(sample_filter);
const uint32_t query_id_offset = bf.query_id_offset;
const auto filter_payload_owner = extract_cagra_sample_filter<SourceIndexT>(sample_filter);
Member

Okay yep, so this is what I was talking about. filter_payload_owner is allocated on the stack here, which means it will be de-allocated at the end of the function call. For externally managed threading, this means that the device filter payload will be allocated and deallocated on every search call, instead of it being a no-op after the first one.

cagra_sample_filter_payload<SourceIndexT> extract_cagra_sample_filter(
  const SampleFilterT& sample_filter)
{
  cagra_sample_filter_payload<SourceIndexT> out;
Member

I think this should be declared static before returning

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@cpp/src/neighbors/detail/cagra/cagra_filter_payload.hpp`:
- Around line 67-75: cagra_device_payload_owner::dev_ptr() can leave cache_[key]
with a partially-initialized state (device_payload set but ready_event invalid)
if any CUDA call after cudaMallocAsync fails; update state::dev_ptr() to perform
rollback on any CUDA failure: if cudaMallocAsync succeeded then on subsequent
error call cudaFreeAsync (or cudaFree if stream not available) to free
device_payload, destroy any created ready_event via cudaEventDestroy, reset
state members (device_payload = nullptr, ready_event = default/0, stream =
nullptr) and ensure the cache entry is removed or marked invalid (erase
cache_[key] or set a flag) so future calls do not use the leaked pointer and
invalid event; wrap CUDA calls with a try/catch or error-check path to execute
this cleanup whenever a RAFT_CUDA_TRY-like check reports failure.

In `@cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh`:
- Around line 623-624: The code binds filter_payload via
extract_cagra_sample_filter<SourceIndexT> to the caller stream but then lets the
persistent kernel consume it on the runner’s private stream, creating a race;
before assigning this->filter_payload (and before the persistent kernel launch
that reads query_id_offset), synchronize the request stream into the runner
stream using raft::stream_syncer (or an event handoff) so the device buffer
updates are visible and ordered; update the path around
extract_cagra_sample_filter<SourceIndexT>, filter_payload and the persistent
kernel launch to perform the stream handoff (use raft::stream_syncer to wait on
the request stream or record/wait an event) so the persistent kernel never reads
stale/torn UDF filter data.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 412ea828-d7c9-4e04-a22c-c9ef9d63f0bd

📥 Commits

Reviewing files that changed from the base of the PR and between ee1475a and aa01651.

📒 Files selected for processing (7)
  • cpp/src/neighbors/detail/cagra/cagra_filter_payload.hpp
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh
  • cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh
  • cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh
  • cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh
  • cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp
💤 Files with no reviewable changes (1)
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_filter_payload.cuh
🚧 Files skipped from review as they are similar to previous changes (4)
  • cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh
  • cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp
  • cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh
  • cpp/src/neighbors/detail/cagra/jit_lto_kernels/sample_filter_udf.cuh

Comment on lines +67 to +75
if (device_payload == nullptr) {
  RAFT_CUDA_TRY(cudaGetDevice(&device));
  RAFT_CUDA_TRY(cudaMallocAsync(
    reinterpret_cast<void**>(&device_payload), sizeof(PayloadT), cuda_stream));
  RAFT_CUDA_TRY(cudaMemcpyAsync(
    device_payload, &host_payload, sizeof(PayloadT), cudaMemcpyHostToDevice, cuda_stream));
  RAFT_CUDA_TRY(cudaEventCreateWithFlags(&ready_event, cudaEventDisableTiming));
  RAFT_CUDA_TRY(cudaEventRecord(ready_event, cuda_stream));
  stream = cuda_stream;

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

file=$(fd -p 'cagra_filter_payload.hpp' | head -n1)

echo '--- state::dev_ptr initialization block ---'
sed -n '64,80p' "$file"

echo
echo '--- cache insertion before initialization ---'
sed -n '116,129p' "$file"

echo
echo '--- fallible CUDA setup calls and nearby cleanup/catch ---'
rg -n 'cudaMallocAsync|cudaMemcpyAsync|cudaEventCreateWithFlags|cudaEventRecord|cudaFreeAsync|catch|device_payload|ready_event' "$file"

Repository: rapidsai/cuvs

Length of output: 2714


🏁 Script executed:

#!/bin/bash
set -euo pipefail

file=$(fd -p 'cagra_filter_payload.hpp' | head -n1)

echo "=== File: $file ==="

echo
echo "---- state struct definition (around lines 40-90) ----"
sed -n '35,95p' "$file"

echo
echo "---- cache insertion (around lines 100-150) ----"
sed -n '95,150p' "$file"

echo
echo "---- get_cagra_device_payload (around lines 150-220) ----"
sed -n '150,230p' "$file"

echo
echo "---- RAFT_CUDA_TRY macro behavior (best-effort grep) ----"
rg -n "define RAFT_CUDA_TRY" -S . | head -n 20 || true

Repository: rapidsai/cuvs

Length of output: 7203


🏁 Script executed:

#!/bin/bash
set -euo pipefail

rg -n "define RAFT_CUDA_TRY\\(" -S . --hidden --no-ignore-vcs

Repository: rapidsai/cuvs

Length of output: 1948


HIGH: Partially-initialized cached device payload on CUDA failure (missing rollback)

cagra_device_payload_owner::dev_ptr() inserts a new state into cache_[key] before state::dev_ptr() runs. If any CUDA call in state::dev_ptr() fails after cudaMallocAsync succeeds (e.g., cudaMemcpyAsync, cudaEventCreateWithFlags, cudaEventRecord), the cached state persists with device_payload != nullptr while ready_event is still default-initialized or unrecorded. Subsequent calls then take the else branch and call cudaStreamWaitEvent(cuda_stream, ready_event, 0) on that event, and because the cached entry is never removed, the allocation leaks for the lifetime of the process.

Suggested fix
     PayloadT* dev_ptr(cudaStream_t cuda_stream)
     {
       std::lock_guard<std::mutex> lock(mutex);
       if (device_payload == nullptr) {
         RAFT_CUDA_TRY(cudaGetDevice(&device));
-        RAFT_CUDA_TRY(cudaMallocAsync(
-          reinterpret_cast<void**>(&device_payload), sizeof(PayloadT), cuda_stream));
-        RAFT_CUDA_TRY(cudaMemcpyAsync(
-          device_payload, &host_payload, sizeof(PayloadT), cudaMemcpyHostToDevice, cuda_stream));
-        RAFT_CUDA_TRY(cudaEventCreateWithFlags(&ready_event, cudaEventDisableTiming));
-        RAFT_CUDA_TRY(cudaEventRecord(ready_event, cuda_stream));
-        stream = cuda_stream;
+        PayloadT* new_device_payload{nullptr};
+        cudaEvent_t new_ready_event{};
+        try {
+          RAFT_CUDA_TRY(cudaMallocAsync(
+            reinterpret_cast<void**>(&new_device_payload), sizeof(PayloadT), cuda_stream));
+          RAFT_CUDA_TRY(cudaMemcpyAsync(new_device_payload,
+                                        &host_payload,
+                                        sizeof(PayloadT),
+                                        cudaMemcpyHostToDevice,
+                                        cuda_stream));
+          RAFT_CUDA_TRY(cudaEventCreateWithFlags(&new_ready_event, cudaEventDisableTiming));
+          RAFT_CUDA_TRY(cudaEventRecord(new_ready_event, cuda_stream));
+          device_payload = new_device_payload;
+          ready_event    = new_ready_event;
+          stream         = cuda_stream;
+        } catch (...) {
+          if (new_ready_event != nullptr) {
+            RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(new_ready_event));
+          }
+          if (new_device_payload != nullptr) {
+            RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(new_device_payload, cuda_stream));
+          }
+          throw;
+        }
       } else {
         RAFT_CUDA_TRY(cudaStreamWaitEvent(cuda_stream, ready_event, 0));
       }
       return device_payload;
     }

Comment on lines +623 to +624
this->filter_payload = extract_cagra_sample_filter<SourceIndexT>(sample_filter, stream);
const uint32_t query_id_offset = filter_payload.query_id_offset;

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

HIGH: Missing stream handoff for UDF filter context on the persistent path.

filter_payload is bound to the runner’s private CUDA stream here, but the PR explicitly allows filter_data to carry per-query runtime metadata. On cache hits, the persistent kernel can consume that device buffer without any ordering against the caller’s stream, so async updates to the UDF context can race and produce stale or torn predicate decisions. Have you considered syncing the request stream into the runner stream before the persistent kernel observes filter_payload (for example with raft::stream_syncer or an event handoff)? As per coding guidelines, use raft::stream_syncer for proper stream ordering in multi-threaded contexts.

Member

@divyegala divyegala left a comment

Minor comments, fantastic PR!


private:
mutable std::mutex cache_mutex_;
mutable std::unordered_map<cache_key, std::vector<std::shared_ptr<state>>, cache_key_hash> cache_;
Member

minor nit: state does not need to be a shared_ptr anymore as the owner itself is static

struct is_udf_filter<::cuvs::neighbors::filtering::udf_filter> : std::true_type {};

template <typename SourceIndexT, typename FilterT>
cagra_filter_data_storage<SourceIndexT> make_cagra_filter_data_storage(const FilterT& filter)
Member

Ok to address later: name this make_cagra_bitset_filter_storage so we can add factories for other pre-compiled types and don't need the specific alias anymore

Comment on lines +162 to +167
template <typename PayloadT>
void* get_cagra_device_payload(PayloadT payload, cudaStream_t stream)
{
  static cagra_device_payload_owner<PayloadT> owner;
  return owner.dev_ptr(payload, stream);
}
Member

Call this in the factory? So the factory returns a perfectly valid device payload directly

{
using DecayedFilter = std::decay_t<FilterT>;
if constexpr (is_bitset_filter<DecayedFilter>::value) {
out.filter_data = get_cagra_device_payload(make_cagra_filter_data_storage<SourceIndexT>(filter),
Member

Would not need to call here then, going by above comment

Labels

feature request New feature or request non-breaking Introduces a non-breaking change

Projects

Status: In Progress

3 participants