Skip to content

Introduce UDF Architecture#1804

Open
divyegala wants to merge 160 commits intorapidsai:release/26.04from
divyegala:ivf-flat-search-udf
Open

Introduce UDF Architecture#1804
divyegala wants to merge 160 commits intorapidsai:release/26.04from
divyegala:ivf-flat-search-udf

Conversation

@divyegala
Copy link
Member

@divyegala divyegala commented Feb 15, 2026

This PR introduces User-Defined-Functions supporting architecture in cuVS and uses JIT LTO to achieve it. The initial example is written for passing a metric UDF to IVF Flat search kernels.

When tested with native L2 metric and UDF L2 metric, we get native performance.
image

@divyegala divyegala linked an issue Mar 4, 2026 that may be closed by this pull request
@divyegala divyegala requested review from jinsolp and tarang-jain March 4, 2026 17:40
std::size_t size);

void registerNVRTCFragment(std::string const& key,
std::unique_ptr<char[]>&& program,
Copy link
Contributor

@mythrocks mythrocks Mar 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't gotten far in this change yet, but I wonder why program is not a vector<char> instead of a unique_ptr<char[]>.
Edit: Or a std::string, really.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think there's a particular reason. We need to use the C type char[] so it's just clearer IMO.

Comment on lines +3148 to +3151
if constexpr (std::is_same_v<T, uint8_t> && V > 1) {
auto diff = __vabsdiffu4(x.raw(), y.raw());
} else if constexpr (std::is_same_v<T, int8_t> && V > 1) {
auto diff = __vabsdiffs4(x.raw(), y.raw());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we be retuning the diff here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch again!

Comment on lines +3132 to +3137
if constexpr (std::is_same_v<T, uint8_t> && V > 1) {
auto diff = __vabsdiffu4(x.raw(), y.raw());
return __dp4a(diff, diff, AccT{0});
} else if constexpr (std::is_same_v<T, int8_t> && V > 1) {
auto diff = __vabsdiffs4(x.raw(), y.raw());
return __dp4a(diff, diff, static_cast<uint32_t>(0));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this should lead to correctness issues but why is the first one using AccT as the accumulator and the second one uint32_t? Can't we do return __dp4a(diff, diff, AccT{0}); for both?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! Nice catch.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh we can't, because __vabsdiffs4 returns a uint32_t.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh okay no worries

Copy link
Contributor

@achirkin achirkin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @divyegala, thanks for the extensive work! From my side, a few small nitpicks, but also two major things to consider:

First, I think the enabling-nvrtc part can be split into a separate PR. Can it?

Second, the JIT code generation. I feel like having two versions of the code, one as a type-checked C++ header and the other one as a collections of std::string values will be hard to maintain in the long run. Which other code parsing/generation approaches have you considered?

  • How about generating/parsing AST programmatically? We could use the heavy-but-feature-complete clang set of tools, or something more lightweight, such as cppast at compile time to parse the header subject to JIT. Use C++ custom function attributes to label and find specifically functions to be saved as strings for compilation. Or perhaps even their dependencies?
  • As a low-effort workaround, maybe just annotate the functions and make a small python parser to extract the pieces of code?

Also, could you please describe (maybe in the PR description) why do we need to go all the way from source code for the UDF functions compiled via nvrtc rather than let the user pre-compile them into LTO IR at build time and use that at runtime?

point() = default;
__device__ __host__ explicit point(storage_type d) : data_(d) {}

__device__ __forceinline__ storage_type raw() const { return data_; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use raft macros for these where appropriate:

Suggested change
__device__ __forceinline__ storage_type raw() const { return data_; }
RAFT_DEVICE_INLINE_FUNCTION storage_type raw() const { return data_; }

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot use those here. We are not allowed to include any headers in the UDF strings.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please write this in bold everywhere where it's relevant. To an unprepared viewer like me, there's no indication if this is the case and why. Another developer will see this with no commentary whatsoever and write the same style outside of the nvrtc/udf context.
This is becoming especially a problem in light of AI coding agents who learn the style from the existing project code and replicate it not-so-thoughtfully.

- if: cuda_major == "13"
then:
- libnvjitlink-dev
- cuda-nvrtc-dev
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see the nvrtc introduction is shared between this PR and #1807 (all related changes in conda, cmake, and C++ headers). It looks to me like an important change set on its own. Could you please move it in a separate PR? This would reduce the diff of both PRs and will make the commit history in the main branch more granular (much more readable git blame).

* @tparam Veclen Vector length (1, 2, 4, 8, 16)
*/
template <typename T, typename AccT, int Veclen>
struct point {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You've probably considered this already, but could you give a small explanation (perhaps right in the docstring for future reference): why can't we use TxN_t or IOType from raft/util/vectorized.cuh here? Either directly or as a wrapped-in carrier type? (adding missing accessors to raft I think is a reasonable idea too)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are pre-existing template types used by interleaved_scan_kernel to express its internal distance computation functions.

I would say it is orthogonal to this PR to update these types. Any decision to refactor interleaved_scan_kernel should come separately from the JIT/LTO feature.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah sorry I didn't look at ivf-flat code for a while and missed when this structure was introduced

};

// ============================================================
// Helper Operations - Deduce Veclen from point type!
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this relate to?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is so that the user does not have to know the value of Veclen to write their UDF, it is just a convenience class.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is so that the user does not have to know the value of Veclen to write their UDF, it is just a convenience class.

Could you please write just that in the file? When I stumbled upon this comment, it looked like an instruction to me or some sort of todo note :)

}

// ============================================================================
// String versions for JIT compilation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like it's going to be hard to maintain. What if we put all definitions subject to JIT in a separate header and then write a small utility module that would read that header as text at compile time and save it to strings?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a little hard and would need us to plug it in with CMake, so that we can embed that generated string back into this header. That said, I do think that it will declutter this header nicely. Would you be okay with a follow-up?

Comment on lines +172 to +187
if (params.metric_udf.has_value()) {
std::string metric_udf = params.metric_udf.value();
// Add explicit template instantiation with actual types
metric_udf += "\ntemplate void cuvs::neighbors::ivf_flat::detail::compute_dist<";
metric_udf += std::to_string(Veclen);
metric_udf += ", ";
metric_udf += type_name<T>();
metric_udf += ", ";
metric_udf += type_name<AccT>();
metric_udf += ">(";
metric_udf += type_name<AccT>();
metric_udf += "&, ";
metric_udf += type_name<AccT>();
metric_udf += ", ";
metric_udf += type_name<AccT>();
metric_udf += ");\n";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're C++20 now, I think using std::format is preferable for being more concise:

metric_udf += fmt::format(
    "\ntemplate void cuvs::neighbors::ivf_flat::detail::compute_dist"
    "<{}, {}, {}>({}&, {}, {});\n",
    Veclen, type_name<T>(), type_name<AccT>(),
    type_name<AccT>(), type_name<AccT>(), type_name<AccT>());

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we have fmt available?

@divyegala
Copy link
Member Author

Thanks for your review @achirkin!

First, I think the enabling-nvrtc part can be split into a separate PR. Can it?

I would prefer not to for a few reasons:

  1. No way to test the enabling-nvrtc part. This PR helps establish an e2e test through our algorithms.
  2. It has been tough to find timely reviews for my JIT LTO/UDF work, so I would like to maintain momentum on this PR while it has reviewers.

Second, the JIT code generation

I will attempt to answer all your questions related to the input at the same time.

At the present moment, we only have two options:

  1. Pass a string to nvrtc -> generate LTO-IR -> link with nvjitlink
  2. User passes LTO-IR -> link with nvjiitlink (we can enable this feature in a follow-up PR, I consider this a feature that would require some savviness from the user themselves)

Even though it appears that we are supporting two methods (header or string), we are actually only supporting method 1. The header is simply a convenience around avoiding using raw strings and letting our users figure out at compile time if their UDF has the right interface to work in our kernels.

Creating any other method to parse the UDF would involve a significant investment from our end. That said, we have other teams already working on a better UX - please check out this nvbug.

Keeping efforts of other teams in mind, I consider the current macro-style type checker to be a fair middle ground and low effort bar for us to help users enforce some compile time compatibility checks while we await further developments.

Copy link
Contributor

@achirkin achirkin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. No way to test the enabling-nvrtc part. This PR helps establish an e2e test through our algorithms.

Is it really not possible to make a simple nvrtc test with a dummy kernel in our gtests? This is concerning.

  1. It has been tough to find timely reviews for my JIT LTO/UDF work, so I would like to maintain momentum on this PR while it has reviewers.

I'm sorry about that and I'm sorry I waited so long to join the review process. Yet I think it's no secret the size of a PR is the main contributor to deterring timely reviews.

On the topic of nvrtc strings. Thank you for the link, the RFE doc is a delight to read :) It outlines two common existing approaches at their extremes: (1) writing in plain strings vs (2) integrating a custom code generator in cmake. You chose (1) and it has its advantages over (2). In my comments I'm suggesting a middle ground as a slight improvement: why not writing an extremely simple parser to just copy the relevant parts of the header that you wrote as a means of compile-time validation? We don't strictly need to integrate it into the build system. We already do this in multiple places to generate template instances. We could, for example, "define" our own attribute and annotate the functions and structs to be passed to nvrtc like this:

[[cuvs::nvrtc_code_start("END_TOKEN")]]
template <typename T, typename AccT, int V>
__device__ __forceinline__ AccT product(point<T, AccT, V> x, point<T, AccT, V> y)
{
  return dot_product(x, y);
}
)
// END_TOKEN

And then extracting all strings would amount to a regex on [[cuvs::nvrtc_code_start\(".*"\)]] plus a simple parse till the declared token. With a proper parser like clang or cppast we could just filter and extract all annotated AST (with raft macros already applied), but even a simple python script goes a long way in making this easier to maintain.

point() = default;
__device__ __host__ explicit point(storage_type d) : data_(d) {}

__device__ __forceinline__ storage_type raw() const { return data_; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please write this in bold everywhere where it's relevant. To an unprepared viewer like me, there's no indication if this is the case and why. Another developer will see this with no commentary whatsoever and write the same style outside of the nvrtc/udf context.
This is becoming especially a problem in light of AI coding agents who learn the style from the existing project code and replicate it not-so-thoughtfully.

@achirkin achirkin requested review from achirkin and removed request for achirkin March 12, 2026 07:54
@achirkin achirkin dismissed their stale review March 12, 2026 08:00

Don't block the other great JIT-LTO/nvrtc work dependent on this PR over the design issues.

@divyegala divyegala changed the base branch from main to release/26.04 March 12, 2026 18:51
@divyegala
Copy link
Member Author

divyegala commented Mar 12, 2026

@achirkin

Is it really not possible to make a simple nvrtc test with a dummy kernel in our gtests? This is concerning.

Hmm, it should be actually. We would have to embed a test kernel in libcuvs.so though, which I don't feel great about adding testing components to the main library.

On the topic of nvrtc strings.

I like your proposed solution, and it definitely reads cleaner and more maintainable. It will take some time to figure out the right design to land on, which is why this feature is in the experimental namespace :). Can you please create an issue describing your idea? We can iterate on the design there and update in 26.06.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

feature request New feature or request non-breaking Introduces a non-breaking change

Projects

Development

Successfully merging this pull request may close these issues.

Introduce UDF Architecture and apply to interleaved_scan_kernel metric functions

8 participants