From 28f6e1dda742c8f557b109b313f2bccb711ea405 Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Thu, 21 May 2026 16:57:53 -0500 Subject: [PATCH 1/2] FEA v1 of Python device UDF support for IVF Flat --- c/include/cuvs/core/all.h | 1 + c/include/cuvs/core/device_udf.h | 74 ++++ c/include/cuvs/neighbors/ivf_flat.h | 6 + c/src/neighbors/ivf_flat.cpp | 30 +- cpp/include/cuvs/core/device_udf.h | 74 ++++ cpp/include/cuvs/core/device_udf.hpp | 101 +++++ .../cuvs/detail/jit_lto/FragmentEntry.hpp | 8 +- cpp/include/cuvs/neighbors/ivf_flat.hpp | 13 +- cpp/src/detail/jit_lto/FragmentEntry.cpp | 8 + .../jit_lto_kernels/device_functions.cuh | 6 +- .../jit_lto_kernels/interleaved_scan_impl.cuh | 4 +- .../interleaved_scan_kernel.cu.in | 2 + .../detail/jit_lto_kernels/kernel_def.hpp | 1 + .../load_and_compute_dist_impl.cuh | 42 +- .../load_and_compute_dist_kernel.cu.in | 6 +- .../jit_lto_kernels/metric_kernel.cu.in | 10 + ...vf_flat_interleaved_scan_explicit_inst.cuh | 5 +- .../ivf_flat_interleaved_scan_ext.cuh | 9 +- .../ivf_flat_interleaved_scan_jit.cuh | 104 ++++- .../neighbors/ivf_flat/ivf_flat_search.cuh | 13 +- cpp/src/neighbors/refine/refine_device.cuh | 1 + cpp/tests/CMakeLists.txt | 1 + cpp/tests/neighbors/ann_ivf_flat/test_udf.cu | 307 ++++++++++++++ python/cuvs/cuvs/_lib/__init__.py | 2 + python/cuvs/cuvs/_lib/device_udf.py | 167 ++++++++ .../cuvs/cuvs/_lib/udf_backends/__init__.py | 2 + .../cuvs/_lib/udf_backends/numba_cuda_mlir.py | 135 ++++++ python/cuvs/cuvs/_lib/udf_validation.py | 297 +++++++++++++ .../cuvs/cuvs/neighbors/ivf_flat/__init__.py | 3 + python/cuvs/cuvs/neighbors/ivf_flat/_udf.py | 396 ++++++++++++++++++ .../cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pxd | 46 +- .../cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx | 119 +++++- python/cuvs/cuvs/tests/test_device_udf.py | 189 +++++++++ python/cuvs/cuvs/tests/test_ivf_flat.py | 201 +++++++++ .../cuvs/cuvs/tests/test_ivf_flat_metric.py | 84 ++++ .../tests/test_numba_cuda_mlir_backend.py | 118 ++++++ 
python/cuvs/cuvs/tests/test_udf_validation.py | 164 ++++++++ 37 files changed, 2706 insertions(+), 43 deletions(-) create mode 100644 c/include/cuvs/core/device_udf.h create mode 100644 cpp/include/cuvs/core/device_udf.h create mode 100644 cpp/include/cuvs/core/device_udf.hpp create mode 100644 python/cuvs/cuvs/_lib/__init__.py create mode 100644 python/cuvs/cuvs/_lib/device_udf.py create mode 100644 python/cuvs/cuvs/_lib/udf_backends/__init__.py create mode 100644 python/cuvs/cuvs/_lib/udf_backends/numba_cuda_mlir.py create mode 100644 python/cuvs/cuvs/_lib/udf_validation.py create mode 100644 python/cuvs/cuvs/neighbors/ivf_flat/_udf.py create mode 100644 python/cuvs/cuvs/tests/test_device_udf.py create mode 100644 python/cuvs/cuvs/tests/test_ivf_flat_metric.py create mode 100644 python/cuvs/cuvs/tests/test_numba_cuda_mlir_backend.py create mode 100644 python/cuvs/cuvs/tests/test_udf_validation.py diff --git a/c/include/cuvs/core/all.h b/c/include/cuvs/core/all.h index 6834f1b095..a1d3133daf 100644 --- a/c/include/cuvs/core/all.h +++ b/c/include/cuvs/core/all.h @@ -10,6 +10,7 @@ #include #include +#include #include #include diff --git a/c/include/cuvs/core/device_udf.h b/c/include/cuvs/core/device_udf.h new file mode 100644 index 0000000000..0043533bfc --- /dev/null +++ b/c/include/cuvs/core/device_udf.h @@ -0,0 +1,74 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once +#ifndef CUVS_CORE_DEVICE_UDF_H +#define CUVS_CORE_DEVICE_UDF_H + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** Device UDF payload representation. */ +typedef enum cuvsDeviceUDFPayloadKind { + CUVS_DEVICE_UDF_PAYLOAD_LTOIR = 1, + CUVS_DEVICE_UDF_PAYLOAD_CUDA_SOURCE = 2, +} cuvsDeviceUDFPayloadKind; + +/** Capture flags for cuvsUDFCapture. */ +enum { CUVS_UDF_CAPTURE_READONLY = 1u }; + +/** Borrowed capture descriptor for a device UDF. 
*/ +typedef struct cuvsUDFCapture { + /** Capture name as used by the frontend, e.g. "weights". */ + const char* name; + /** Logical dtype string, e.g. "float32". */ + const char* dtype; + /** Optional shape array of length ndim. */ + const int64_t* shape; + /** Optional strides array of length ndim, in bytes. Null means contiguous/default. */ + const int64_t* strides; + /** Number of dimensions in shape/strides. */ + int32_t ndim; + /** CUDA device ordinal for the capture allocation. */ + int32_t device_id; + /** CUDA device pointer for the capture allocation. */ + uintptr_t pointer; + /** Bitmask of CUVS_UDF_CAPTURE_* flags. */ + uint32_t flags; +} cuvsUDFCapture; + +/** Borrowed device UDF descriptor. Payload and captures are copied by consumers. */ +typedef struct cuvsDeviceUDF { + /** ABI identifier, e.g. "rapids.cuvs.ivf_flat.metric.v1". */ + const char* abi; + /** Payload representation. */ + cuvsDeviceUDFPayloadKind payload_kind; + /** Borrowed payload bytes. */ + const void* payload; + /** Number of bytes in payload. */ + size_t payload_size; + /** Optional externally visible device symbol in payload. */ + const char* symbol_name; + /** Borrowed capture descriptors. */ + const cuvsUDFCapture* captures; + /** Number of capture descriptors. */ + size_t n_captures; + /** Cache key used to identify this UDF artifact. */ + const char* cache_key; + /** Reserved flags; must be zero for now. 
*/ + uint32_t flags; +} cuvsDeviceUDF; + +typedef const cuvsDeviceUDF* cuvsDeviceUDF_t; + +#ifdef __cplusplus +} +#endif + +#endif // CUVS_CORE_DEVICE_UDF_H diff --git a/c/include/cuvs/neighbors/ivf_flat.h b/c/include/cuvs/neighbors/ivf_flat.h index e1531eb091..597c5e655f 100644 --- a/c/include/cuvs/neighbors/ivf_flat.h +++ b/c/include/cuvs/neighbors/ivf_flat.h @@ -6,10 +6,12 @@ #pragma once #include +#include #include #include #include #include +#include #include #include @@ -103,6 +105,8 @@ CUVS_EXPORT cuvsError_t cuvsIvfFlatIndexParamsDestroy(cuvsIvfFlatIndexParams_t i struct cuvsIvfFlatSearchParams { /** The number of clusters to search. */ uint32_t n_probes; + /** Optional borrowed device UDF descriptor for a custom metric. */ + cuvsDeviceUDF_t metric_udf; }; typedef struct cuvsIvfFlatSearchParams* cuvsIvfFlatSearchParams_t; @@ -201,6 +205,7 @@ CUVS_EXPORT cuvsError_t cuvsIvfFlatIndexGetCenters(cuvsIvfFlatIndex_t index, DLM * * @code {.c} * #include + * #include * #include * * // Create cuvsResources_t @@ -257,6 +262,7 @@ CUVS_EXPORT cuvsError_t cuvsIvfFlatBuild(cuvsResources_t res, * * @code {.c} * #include + * #include * #include * * // Create cuvsResources_t diff --git a/c/src/neighbors/ivf_flat.cpp b/c/src/neighbors/ivf_flat.cpp index 56a3088e89..ec8da5fdde 100644 --- a/c/src/neighbors/ivf_flat.cpp +++ b/c/src/neighbors/ivf_flat.cpp @@ -5,6 +5,8 @@ */ #include +#include + #include #include @@ -21,6 +23,7 @@ #include "../core/interop.hpp" #include +#include namespace cuvs::neighbors::ivf_flat { void convert_c_index_params(cuvsIvfFlatIndexParams params, @@ -39,6 +42,28 @@ void convert_c_search_params(cuvsIvfFlatSearchParams params, cuvs::neighbors::ivf_flat::search_params* out) { out->n_probes = params.n_probes; + + if (params.metric_udf == nullptr) { return; } + + auto metric_udf = cuvs::jit::make_device_udf(*params.metric_udf); + RAFT_EXPECTS(metric_udf.abi == "rapids.cuvs.ivf_flat.metric.v1", + "Unsupported IVF Flat metric UDF ABI: %s", + 
metric_udf.abi.c_str()); + + if (metric_udf.payload_kind == cuvs::jit::device_udf_payload_kind::cuda_source) { + RAFT_EXPECTS(metric_udf.captures.empty(), + "IVF Flat CUDA source metric UDF currently does not support captures"); + out->metric_udf = std::string{reinterpret_cast(metric_udf.payload.data()), + metric_udf.payload.size()}; + return; + } + + RAFT_EXPECTS(metric_udf.payload_kind == cuvs::jit::device_udf_payload_kind::ltoir, + "IVF Flat metric UDF currently requires an LTO-IR or CUDA source payload"); + RAFT_EXPECTS(metric_udf.captures.size() <= 1, + "IVF Flat metric UDF currently supports at most one capture"); + + out->metric_ltoir_udf = std::move(metric_udf); } } // namespace cuvs::neighbors::ivf_flat @@ -282,8 +307,9 @@ extern "C" cuvsError_t cuvsIvfFlatIndexParamsDestroy(cuvsIvfFlatIndexParams_t pa extern "C" cuvsError_t cuvsIvfFlatSearchParamsCreate(cuvsIvfFlatSearchParams_t* params) { - return cuvs::core::translate_exceptions( - [=] { *params = new cuvsIvfFlatSearchParams{.n_probes = 20}; }); + return cuvs::core::translate_exceptions([=] { + *params = new cuvsIvfFlatSearchParams{.n_probes = 20, .metric_udf = nullptr}; + }); } extern "C" cuvsError_t cuvsIvfFlatSearchParamsDestroy(cuvsIvfFlatSearchParams_t params) diff --git a/cpp/include/cuvs/core/device_udf.h b/cpp/include/cuvs/core/device_udf.h new file mode 100644 index 0000000000..0043533bfc --- /dev/null +++ b/cpp/include/cuvs/core/device_udf.h @@ -0,0 +1,74 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once +#ifndef CUVS_CORE_DEVICE_UDF_H +#define CUVS_CORE_DEVICE_UDF_H + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** Device UDF payload representation. */ +typedef enum cuvsDeviceUDFPayloadKind { + CUVS_DEVICE_UDF_PAYLOAD_LTOIR = 1, + CUVS_DEVICE_UDF_PAYLOAD_CUDA_SOURCE = 2, +} cuvsDeviceUDFPayloadKind; + +/** Capture flags for cuvsUDFCapture. 
*/ +enum { CUVS_UDF_CAPTURE_READONLY = 1u }; + +/** Borrowed capture descriptor for a device UDF. */ +typedef struct cuvsUDFCapture { + /** Capture name as used by the frontend, e.g. "weights". */ + const char* name; + /** Logical dtype string, e.g. "float32". */ + const char* dtype; + /** Optional shape array of length ndim. */ + const int64_t* shape; + /** Optional strides array of length ndim, in bytes. Null means contiguous/default. */ + const int64_t* strides; + /** Number of dimensions in shape/strides. */ + int32_t ndim; + /** CUDA device ordinal for the capture allocation. */ + int32_t device_id; + /** CUDA device pointer for the capture allocation. */ + uintptr_t pointer; + /** Bitmask of CUVS_UDF_CAPTURE_* flags. */ + uint32_t flags; +} cuvsUDFCapture; + +/** Borrowed device UDF descriptor. Payload and captures are copied by consumers. */ +typedef struct cuvsDeviceUDF { + /** ABI identifier, e.g. "rapids.cuvs.ivf_flat.metric.v1". */ + const char* abi; + /** Payload representation. */ + cuvsDeviceUDFPayloadKind payload_kind; + /** Borrowed payload bytes. */ + const void* payload; + /** Number of bytes in payload. */ + size_t payload_size; + /** Optional externally visible device symbol in payload. */ + const char* symbol_name; + /** Borrowed capture descriptors. */ + const cuvsUDFCapture* captures; + /** Number of capture descriptors. */ + size_t n_captures; + /** Cache key used to identify this UDF artifact. */ + const char* cache_key; + /** Reserved flags; must be zero for now. */ + uint32_t flags; +} cuvsDeviceUDF; + +typedef const cuvsDeviceUDF* cuvsDeviceUDF_t; + +#ifdef __cplusplus +} +#endif + +#endif // CUVS_CORE_DEVICE_UDF_H diff --git a/cpp/include/cuvs/core/device_udf.hpp b/cpp/include/cuvs/core/device_udf.hpp new file mode 100644 index 0000000000..b6b4143ae8 --- /dev/null +++ b/cpp/include/cuvs/core/device_udf.hpp @@ -0,0 +1,101 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include + +namespace cuvs::jit { + +enum class device_udf_payload_kind { ltoir, cuda_source }; + +struct udf_capture { + std::string name; + std::string dtype; + std::vector shape; + std::vector strides; + int32_t device_id = 0; + std::uintptr_t pointer = 0; + bool readonly = true; +}; + +struct device_udf { + std::string abi; + device_udf_payload_kind payload_kind = device_udf_payload_kind::ltoir; + std::vector payload; + std::string symbol_name; + std::string cache_key; + std::vector captures; +}; + +using ltoir_udf = device_udf; + +inline device_udf_payload_kind payload_kind_from_c(cuvsDeviceUDFPayloadKind kind) +{ + switch (kind) { + case CUVS_DEVICE_UDF_PAYLOAD_LTOIR: return device_udf_payload_kind::ltoir; + case CUVS_DEVICE_UDF_PAYLOAD_CUDA_SOURCE: return device_udf_payload_kind::cuda_source; + } + RAFT_FAIL("Unsupported cuVS device UDF payload kind: %d", static_cast(kind)); + return device_udf_payload_kind::ltoir; +} + +inline device_udf make_device_udf(cuvsDeviceUDF const& desc) +{ + RAFT_EXPECTS(desc.abi != nullptr, "device UDF abi must not be null"); + RAFT_EXPECTS(desc.payload != nullptr, "device UDF payload must not be null"); + RAFT_EXPECTS(desc.payload_size > 0, "device UDF payload_size must be non-zero"); + RAFT_EXPECTS(desc.symbol_name != nullptr, "device UDF symbol_name must not be null"); + RAFT_EXPECTS(desc.cache_key != nullptr, "device UDF cache_key must not be null"); + RAFT_EXPECTS(desc.flags == 0, "device UDF flags must be zero"); + RAFT_EXPECTS(desc.n_captures == 0 || desc.captures != nullptr, + "device UDF captures must not be null when n_captures is non-zero"); + + auto const* payload_begin = static_cast(desc.payload); + auto out = device_udf{.abi = std::string{desc.abi}, + .payload_kind = payload_kind_from_c(desc.payload_kind), + .payload = std::vector{ + payload_begin, payload_begin + desc.payload_size}, + 
.symbol_name = std::string{desc.symbol_name}, + .cache_key = std::string{desc.cache_key}}; + + out.captures.reserve(desc.n_captures); + for (size_t i = 0; i < desc.n_captures; ++i) { + auto const& capture = desc.captures[i]; + RAFT_EXPECTS(capture.name != nullptr, "device UDF capture name must not be null"); + RAFT_EXPECTS(capture.dtype != nullptr, "device UDF capture dtype must not be null"); + RAFT_EXPECTS(capture.ndim >= 0, "device UDF capture ndim must be non-negative"); + RAFT_EXPECTS(capture.ndim == 0 || capture.shape != nullptr, + "device UDF capture shape must not be null when ndim is non-zero"); + RAFT_EXPECTS(capture.pointer != 0, "device UDF capture pointer must not be zero"); + RAFT_EXPECTS((capture.flags & ~CUVS_UDF_CAPTURE_READONLY) == 0, + "device UDF capture has unsupported flags"); + + auto next = udf_capture{.name = std::string{capture.name}, + .dtype = std::string{capture.dtype}, + .device_id = capture.device_id, + .pointer = capture.pointer, + .readonly = (capture.flags & CUVS_UDF_CAPTURE_READONLY) != 0}; + + auto const ndim = static_cast(capture.ndim); + if (ndim > 0) { next.shape.assign(capture.shape, capture.shape + ndim); } + if (capture.strides != nullptr) { + next.strides.assign(capture.strides, capture.strides + ndim); + } + out.captures.push_back(std::move(next)); + } + + return out; +} + +} // namespace cuvs::jit diff --git a/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp b/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp index 35aa46633c..a9f96c62aa 100644 --- a/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp +++ b/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp @@ -46,15 +46,17 @@ struct StaticFatbinFragmentEntry final : FatbinFragmentEntry { static const size_t length; }; -struct UDFFatbinFragment final : FatbinFragmentEntry { +struct UDFFatbinFragment final : FragmentEntry { UDFFatbinFragment(std::string key, std::vector bytes) : key_(std::move(key)), bytes_(std::move(bytes)) { } - const uint8_t* get_data() const override { 
return bytes_.data(); } + const uint8_t* get_data() const { return bytes_.data(); } - size_t get_length() const override { return bytes_.size(); } + size_t get_length() const { return bytes_.size(); } + + bool add_to(nvJitLinkHandle& handle) const override; const char* get_key() const override { return key_.c_str(); } diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index a8fb4f49be..ff7f073e38 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -7,10 +7,12 @@ #include "common.hpp" #include +#include #include #include #include #include +#include #include #include @@ -76,8 +78,10 @@ struct index_params : cuvs::neighbors::index_params { struct search_params : cuvs::neighbors::search_params { /** The number of clusters to search. */ uint32_t n_probes = 20; - /** Custom metric UDF code. */ + /** Custom metric UDF CUDA/C++ source code. */ std::optional metric_udf = std::nullopt; + /** Custom metric UDF LTO-IR artifact. 
*/ + std::optional metric_ltoir_udf = std::nullopt; }; static_assert(std::is_aggregate_v); @@ -3481,8 +3485,15 @@ inline std::string instantiate_udf(char const* data_type, char const* acc_type, << " compute_dist_udf_impl<" << data_type << ", " << acc_type << ", " << veclen << ">(acc, x, y);\n" << "}\n" + << "template \n" + << "__device__ void compute_dist(AccT& acc, AccT x, AccT y, unsigned int, const void*) {\n" + << " compute_dist_udf_impl<" << data_type << ", " << acc_type << ", " << veclen + << ">(acc, x, y);\n" + << "}\n" << "template __device__ void compute_dist<" << acc_type << ">(" << acc_type << "&, " << acc_type << ", " << acc_type << ");\n" + << "template __device__ void compute_dist<" << acc_type << ">(" << acc_type << "&, " + << acc_type << ", " << acc_type << ", unsigned int, const void*);\n" << "}\n"; return oss.str(); } diff --git a/cpp/src/detail/jit_lto/FragmentEntry.cpp b/cpp/src/detail/jit_lto/FragmentEntry.cpp index bf0893c8a6..116b259875 100644 --- a/cpp/src/detail/jit_lto/FragmentEntry.cpp +++ b/cpp/src/detail/jit_lto/FragmentEntry.cpp @@ -12,3 +12,11 @@ bool FatbinFragmentEntry::add_to(nvJitLinkHandle& handle) const check_nvjitlink_result(handle, result); return true; } + +bool UDFFatbinFragment::add_to(nvJitLinkHandle& handle) const +{ + auto result = nvJitLinkAddData(handle, NVJITLINK_INPUT_LTOIR, get_data(), get_length(), get_key()); + + check_nvjitlink_result(handle, result); + return true; +} diff --git a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/device_functions.cuh b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/device_functions.cuh index 0a95d392f6..8855748885 100644 --- a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/device_functions.cuh +++ b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/device_functions.cuh @@ -18,11 +18,15 @@ __device__ float load_and_compute_dist(AccT& dist, const T* query, T* query_shared, const uint32_t dim, - const uint32_t query_smem_elems); + const uint32_t query_smem_elems, + const 
void* metric_capture_0); template __device__ void compute_dist(AccT& acc, AccT x, AccT y); +template +__device__ void compute_dist(AccT& acc, AccT x, AccT y, uint32_t dim, const void* metric_capture_0); + template __device__ T post_process(T val); diff --git a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_impl.cuh b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_impl.cuh index 66a13a6cba..c2e2dc2140 100644 --- a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_impl.cuh +++ b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_impl.cuh @@ -79,6 +79,7 @@ __device__ __forceinline__ void interleaved_scan_impl(const uint32_t query_smem_ const uint32_t max_samples, const uint32_t* chunk_indices, const uint32_t dim, + const void* metric_capture_0, IdxT* const* const inds_ptrs, uint32_t* bitset_ptr, IdxT bitset_len, @@ -175,7 +176,8 @@ __device__ __forceinline__ void interleaved_scan_impl(const uint32_t query_smem_ query, query_shared, dim, - query_smem_elems); + query_smem_elems, + metric_capture_0); } if constexpr (kManageLocalTopK) { diff --git a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_kernel.cu.in b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_kernel.cu.in index 8214dfdb11..ac1c831406 100644 --- a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_kernel.cu.in +++ b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_kernel.cu.in @@ -30,6 +30,7 @@ extern "C" __global__ __launch_bounds__(kThreadsPerBlock) void interleaved_scan( const uint32_t max_samples, const uint32_t* chunk_indices, const uint32_t dim, + const void* metric_capture_0, index_t* const* const inds_ptrs, uint32_t* bitset_ptr, index_t bitset_len, @@ -48,6 +49,7 @@ extern "C" __global__ __launch_bounds__(kThreadsPerBlock) void interleaved_scan( max_samples, chunk_indices, dim, + metric_capture_0, inds_ptrs, bitset_ptr, bitset_len, diff --git 
a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/kernel_def.hpp b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/kernel_def.hpp index 7a73d001ec..f7db223b53 100644 --- a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/kernel_def.hpp +++ b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/kernel_def.hpp @@ -21,6 +21,7 @@ using interleaved_scan_func_t = void(const uint32_t query_smem_elems, const uint32_t max_samples, const uint32_t* chunk_indices, const uint32_t dim, + const void* metric_capture_0, IdxT* const* const inds_ptrs, uint32_t* bitset_ptr, IdxT bitset_len, diff --git a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/load_and_compute_dist_impl.cuh b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/load_and_compute_dist_impl.cuh index ccb6ef7e54..2835e80362 100644 --- a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/load_and_compute_dist_impl.cuh +++ b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/load_and_compute_dist_impl.cuh @@ -30,9 +30,11 @@ struct loadAndComputeDist { AccT& dist; AccT& norm_query; AccT& norm_data; + const void* metric_capture_0; - __device__ __forceinline__ loadAndComputeDist(AccT& dist, AccT& norm_query, AccT& norm_data) - : dist(dist), norm_query(norm_query), norm_data(norm_data) + __device__ __forceinline__ loadAndComputeDist( + AccT& dist, AccT& norm_query, AccT& norm_data, const void* metric_capture_0 = nullptr) + : dist(dist), norm_query(norm_query), norm_data(norm_data), metric_capture_0(metric_capture_0) { } @@ -55,7 +57,7 @@ struct loadAndComputeDist { raft::lds(queryRegs, &query_shared[shmemIndex + j * Veclen]); #pragma unroll for (int k = 0; k < Veclen; ++k) { - compute_dist(dist, queryRegs[k], encV[k]); + compute_dist(dist, queryRegs[k], encV[k], shmemIndex + j * Veclen + k, metric_capture_0); if constexpr (ComputeNorm) { norm_query += queryRegs[k] * queryRegs[k]; norm_data += encV[k] * encV[k]; @@ -90,7 +92,7 @@ struct loadAndComputeDist { #pragma unroll for (int k = 0; k < Veclen; ++k) { T q = 
raft::shfl(queryReg, d + k, raft::WarpSize); - compute_dist(dist, q, encV[k]); + compute_dist(dist, q, encV[k], baseLoadIndex + d + k, metric_capture_0); if constexpr (ComputeNorm) { norm_query += q * q; norm_data += encV[k] * encV[k]; @@ -116,7 +118,7 @@ struct loadAndComputeDist { #pragma unroll for (int k = 0; k < Veclen; k++) { T q = raft::shfl(queryReg, d + k, raft::WarpSize); - compute_dist(dist, q, enc[k]); + compute_dist(dist, q, enc[k], dimBlocks + d + k, metric_capture_0); if constexpr (ComputeNorm) { norm_query += q * q; norm_data += enc[k] * enc[k]; @@ -135,7 +137,8 @@ struct loadAndComputeDist __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, uint32_t& norm_query, - uint32_t& norm_data) + uint32_t& norm_data, + const void* = nullptr) : dist(dist), norm_query(norm_query), norm_data(norm_data) { } @@ -234,7 +237,8 @@ struct loadAndComputeDist { __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, uint32_t& norm_query, - uint32_t& norm_data) + uint32_t& norm_data, + const void* = nullptr) : dist(dist), norm_query(norm_query), norm_data(norm_data) { } @@ -309,7 +313,8 @@ struct loadAndComputeDist { __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, uint32_t& norm_query, - uint32_t& norm_data) + uint32_t& norm_data, + const void* = nullptr) : dist(dist), norm_query(norm_query), norm_data(norm_data) { } @@ -385,7 +390,8 @@ struct loadAndComputeDist { __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, uint32_t& norm_query, - uint32_t& norm_data) + uint32_t& norm_data, + const void* = nullptr) : dist(dist), norm_query(norm_query), norm_data(norm_data) { } @@ -461,7 +467,8 @@ struct loadAndComputeDist { __device__ __forceinline__ loadAndComputeDist(int32_t& dist, int32_t& norm_query, - int32_t& norm_data) + int32_t& norm_data, + const void* = nullptr) : dist(dist), norm_query(norm_query), norm_data(norm_data) { } @@ -555,7 +562,8 @@ struct loadAndComputeDist { int32_t& norm_data; __device__ __forceinline__ 
loadAndComputeDist(int32_t& dist, int32_t& norm_query, - int32_t& norm_data) + int32_t& norm_data, + const void* = nullptr) : dist(dist), norm_query(norm_query), norm_data(norm_data) { } @@ -626,7 +634,8 @@ struct loadAndComputeDist { int32_t& norm_data; __device__ __forceinline__ loadAndComputeDist(int32_t& dist, int32_t& norm_query, - int32_t& norm_data) + int32_t& norm_data, + const void* = nullptr) : dist(dist), norm_query(norm_query), norm_data(norm_data) { } @@ -696,7 +705,8 @@ __device__ float load_and_compute_dist_impl(AccT& dist, const T* query, T* query_shared, const uint32_t dim, - const uint32_t query_smem_elems) + const uint32_t query_smem_elems, + const void* metric_capture_0) { using align_warp = raft::Pow2; constexpr int kUnroll = raft::WarpSize / Veclen; @@ -707,7 +717,7 @@ __device__ float load_and_compute_dist_impl(AccT& dist, const uint32_t full_warps_along_dim = align_warp::roundDown(dim); // Process first shm_assisted_dim dimensions (always using shared memory) - loadAndComputeDist lc(dist, norm_query, norm_dataset); + loadAndComputeDist lc(dist, norm_query, norm_dataset, metric_capture_0); for (int pos = 0; pos < shm_assisted_dim; pos += raft::WarpSize, data += kIndexGroupSize * raft::WarpSize) { lc.runLoadShmemCompute(data, query_shared, lane_id, pos); @@ -715,14 +725,14 @@ __device__ float load_and_compute_dist_impl(AccT& dist, if (dim > query_smem_elems) { // The default path - using shfl ops - for dimensions beyond query_smem_elems - loadAndComputeDist lc(dist, norm_query, norm_dataset); + loadAndComputeDist lc(dist, norm_query, norm_dataset, metric_capture_0); for (int pos = shm_assisted_dim; pos < full_warps_along_dim; pos += raft::WarpSize) { lc.runLoadShflAndCompute(data, query, pos, lane_id); } lc.runLoadShflAndComputeRemainder(data, query, lane_id, dim, full_warps_along_dim); } else { // when shm_assisted_dim == full_warps_along_dim < dim - loadAndComputeDist<1, Veclen, T, AccT, ComputeNorm> lc(dist, norm_query, norm_dataset); + 
loadAndComputeDist<1, Veclen, T, AccT, ComputeNorm> lc(dist, norm_query, norm_dataset, metric_capture_0); for (int pos = full_warps_along_dim; pos < dim; pos += Veclen, data += kIndexGroupSize * Veclen) { lc.runLoadShmemCompute(data, query_shared, lane_id, pos); diff --git a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/load_and_compute_dist_kernel.cu.in b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/load_and_compute_dist_kernel.cu.in index 779705e53c..ddcfe2f08e 100644 --- a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/load_and_compute_dist_kernel.cu.in +++ b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/load_and_compute_dist_kernel.cu.in @@ -26,7 +26,8 @@ __device__ float load_and_compute_dist(acc_t& dist, const data_t* query, data_t* query_shared, const uint32_t dim, - const uint32_t query_smem_elems) + const uint32_t query_smem_elems, + const void* metric_capture_0) { return load_and_compute_dist_impl(dist, norm_query, @@ -36,7 +37,8 @@ __device__ float load_and_compute_dist(acc_t& dist, query, query_shared, dim, - query_smem_elems); + query_smem_elems, + metric_capture_0); } } // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/metric_kernel.cu.in b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/metric_kernel.cu.in index eb3e4e0328..63dd0d6c3f 100644 --- a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/metric_kernel.cu.in +++ b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/metric_kernel.cu.in @@ -25,4 +25,14 @@ __device__ void compute_dist(acc_t& acc, acc_t x, acc_t y) compute_dist_impl(acc, x, y); } +template <> +__device__ void compute_dist(acc_t& acc, + acc_t x, + acc_t y, + uint32_t, + const void*) +{ + compute_dist_impl(acc, x, y); +} + } // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_explicit_inst.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_explicit_inst.cuh index 052b7bfe9a..c2b0929dcc 
100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_explicit_inst.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_explicit_inst.cuh @@ -8,6 +8,7 @@ #include "../detail/ann_utils.cuh" #include "ivf_flat_interleaved_scan_jit.cuh" #include +#include #include #include #include @@ -37,6 +38,8 @@ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream, \ - const std::optional& metric_udf); + const std::optional& metric_udf, \ + const std::optional& \ + metric_ltoir_udf); #define COMMA , diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_ext.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_ext.cuh index 3a782822b4..a99731e1be 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_ext.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_ext.cuh @@ -6,6 +6,7 @@ #pragma once #include +#include #include #include #include @@ -36,7 +37,9 @@ void ivfflat_interleaved_scan(const index& index, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream, - const std::optional& metric_udf) RAFT_EXPLICIT; + const std::optional& metric_udf, + const std::optional& metric_ltoir_udf) + RAFT_EXPLICIT; #define CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(T, IdxT, SampleFilterT) \ extern template void \ @@ -59,7 +62,9 @@ void ivfflat_interleaved_scan(const index& index, float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream, \ - const std::optional& metric_udf); + const std::optional& metric_udf, \ + const std::optional& \ + metric_ltoir_udf); CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(float, int64_t, cuvs::neighbors::filtering::none_sample_filter); diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_jit.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_jit.cuh index a9748ef836..00dc5415e0 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_jit.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_jit.cuh @@ -9,12 +9,14 @@ #include 
"detail/jit_lto_kernels/interleaved_scan_planner.hpp" #include "detail/jit_lto_kernels/kernel_def.hpp" #include +#include #include #include #include #include #include #include +#include #include #include @@ -61,6 +63,56 @@ constexpr auto get_idx_type_tag() template inline constexpr bool type_name_always_false_v = false; +inline const char* capture_pointer_type(std::string const& dtype) +{ + if (dtype == "float32") { return "const float*"; } + if (dtype == "int32") { return "const int*"; } + if (dtype == "int64") { return "const long long*"; } + RAFT_FAIL("Unsupported IVF Flat metric UDF capture dtype: %s", dtype.c_str()); + return ""; +} + +inline std::string instantiate_ltoir_udf_wrapper(cuvs::jit::ltoir_udf const& udf, + char const* acc_type) +{ + RAFT_EXPECTS(udf.captures.size() <= 1, + "IVF Flat metric LTO-IR UDF currently supports at most one capture"); + + auto const symbol_name = udf.symbol_name.c_str(); + std::ostringstream oss; + if (udf.captures.empty()) { + oss << "\nextern \"C\" __device__ " << acc_type << " " << symbol_name << "(" << acc_type + << " x, " << acc_type << " y, " << acc_type << " acc);\n"; + } else { + auto const* capture_type = capture_pointer_type(udf.captures[0].dtype); + oss << "\nextern \"C\" __device__ " << acc_type << " " << symbol_name << "(" << acc_type + << " x, " << acc_type << " y, " << acc_type << " acc, " << capture_type + << " capture_0, long long dim);\n"; + } + + oss << "namespace cuvs::neighbors::ivf_flat::detail {\n" + << "template \n" + << "__device__ void compute_dist(AccT& acc, AccT x, AccT y, unsigned int dim, const void* capture_0) {\n"; + if (udf.captures.empty()) { + oss << " acc = " << symbol_name << "(x, y, acc);\n"; + } else { + auto const* capture_type = capture_pointer_type(udf.captures[0].dtype); + oss << " acc = " << symbol_name << "(x, y, acc, static_cast<" << capture_type + << ">(capture_0), static_cast(dim));\n"; + } + oss << "}\n" + << "template \n" + << "__device__ void compute_dist(AccT& acc, AccT x, 
AccT y) {\n" + << " compute_dist(acc, x, y, 0u, nullptr);\n" + << "}\n" + << "template __device__ void compute_dist<" << acc_type << ">(" << acc_type << "&, " + << acc_type << ", " << acc_type << ");\n" + << "template __device__ void compute_dist<" << acc_type << ">(" << acc_type << "&, " + << acc_type << ", " << acc_type << ", unsigned int, const void*);\n" + << "}\n"; + return oss.str(); +} + template constexpr const char* type_name() { @@ -152,7 +204,8 @@ void launch_kernel(const index& index, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream, - const std::optional& metric_udf) + const std::optional& metric_udf, + const std::optional& metric_ltoir_udf) { RAFT_EXPECTS(Veclen == index.veclen(), "Configured Veclen does not match the index interleaving pattern."); @@ -164,13 +217,43 @@ void launch_kernel(const index& index, InterleavedScanPlanner kernel_planner; kernel_planner.add_entrypoint(); + const void* metric_capture_0 = nullptr; + if constexpr (std::is_same_v) { - RAFT_EXPECTS(metric_udf.has_value(), "CustomUDF search requires metric_udf"); - std::string metric_udf_code = metric_udf.value(); - metric_udf_code += - experimental::udf::instantiate_udf(type_name(), type_name(), Veclen); - auto udf_fragment = nvrtc_compiler().compile(metric_udf_code, metric_udf_code); - kernel_planner.add_metric_udf_fragment(std::move(udf_fragment)); + RAFT_EXPECTS(metric_udf.has_value() != metric_ltoir_udf.has_value(), + "CustomUDF search requires exactly one of metric_udf or metric_ltoir_udf"); + if (metric_ltoir_udf.has_value()) { + auto const& udf = metric_ltoir_udf.value(); + RAFT_EXPECTS(udf.abi == "rapids.cuvs.ivf_flat.metric.v1", + "Unsupported IVF Flat metric LTO-IR UDF ABI: %s", + udf.abi.c_str()); + RAFT_EXPECTS(udf.payload_kind == cuvs::jit::device_udf_payload_kind::ltoir, + "IVF Flat metric_ltoir_udf requires an LTO-IR payload"); + RAFT_EXPECTS(!udf.payload.empty(), "metric_ltoir_udf payload must not be empty"); + 
RAFT_EXPECTS(!udf.symbol_name.empty(), "metric_ltoir_udf symbol_name must not be empty"); + RAFT_EXPECTS(!udf.cache_key.empty(), "metric_ltoir_udf cache_key must not be empty"); + RAFT_EXPECTS(udf.captures.size() <= 1, + "metric_ltoir_udf currently supports at most one capture"); + if (!udf.captures.empty()) { + RAFT_EXPECTS(udf.captures[0].pointer != 0, + "metric_ltoir_udf capture pointer must not be zero"); + metric_capture_0 = reinterpret_cast(udf.captures[0].pointer); + } + + auto payload_fragment = std::make_unique(udf.cache_key, udf.payload); + kernel_planner.add_metric_udf_fragment(std::move(payload_fragment)); + + auto wrapper_code = instantiate_ltoir_udf_wrapper(udf, type_name()); + auto wrapper_key = udf.cache_key + ":ivf_flat_metric_wrapper:" + std::to_string(Veclen); + auto wrapper_fragment = nvrtc_compiler().compile(wrapper_key, wrapper_code); + kernel_planner.add_metric_udf_fragment(std::move(wrapper_fragment)); + } else { + std::string metric_udf_code = metric_udf.value(); + metric_udf_code += + experimental::udf::instantiate_udf(type_name(), type_name(), Veclen); + auto udf_fragment = nvrtc_compiler().compile(metric_udf_code, metric_udf_code); + kernel_planner.add_metric_udf_fragment(std::move(udf_fragment)); + } } else { kernel_planner.add_metric_device_function(); } @@ -233,6 +316,7 @@ void launch_kernel(const index& index, max_samples, chunk_indices, index.dim(), + metric_capture_0, // sample_filter, inds_ptrs, bitset_ptr.value_or(nullptr), @@ -436,7 +520,8 @@ void ivfflat_interleaved_scan(const index& index, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream, - const std::optional& metric_udf) + const std::optional& metric_udf, + const std::optional& metric_ltoir_udf) { const uint32_t n_probes_clamped = std::min(n_probes, index.n_lists()); const int capacity = raft::bound_by_power_of_two(k); @@ -473,7 +558,8 @@ void ivfflat_interleaved_scan(const index& index, distances, grid_dim_x, stream, - metric_udf); + metric_udf, + 
metric_ltoir_udf); } } // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh index 960d48c818..fe7df59443 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh @@ -209,7 +209,8 @@ void search_impl(raft::resources const& handle, nullptr, grid_dim_x, stream, - params.metric_udf); + params.metric_udf, + params.metric_ltoir_udf); } else { grid_dim_x = 1; } @@ -265,7 +266,8 @@ void search_impl(raft::resources const& handle, distances_dev_ptr, grid_dim_x, stream, - params.metric_udf); + params.metric_udf, + params.metric_ltoir_udf); RAFT_LOG_TRACE_VEC(distances_dev_ptr, 2 * k); if (indices_dev_ptr != nullptr) { RAFT_LOG_TRACE_VEC(indices_dev_ptr, 2 * k); } @@ -393,8 +395,13 @@ void search_with_filtering(raft::resources const& handle, RAFT_EXPECTS(queries.extent(1) == index.dim(), "Number of query dimensions should equal number of dimensions in the index."); + RAFT_EXPECTS(!(params.metric_udf.has_value() && params.metric_ltoir_udf.has_value()), + "Only one of search_params.metric_udf and search_params.metric_ltoir_udf may be set"); + auto effective_metric = - params.metric_udf.has_value() ? cuvs::distance::DistanceType::CustomUDF : index.metric(); + (params.metric_udf.has_value() || params.metric_ltoir_udf.has_value()) + ? 
cuvs::distance::DistanceType::CustomUDF + : index.metric(); search_with_filtering(handle, params, diff --git a/cpp/src/neighbors/refine/refine_device.cuh b/cpp/src/neighbors/refine/refine_device.cuh index e027dca53b..adecdef02c 100644 --- a/cpp/src/neighbors/refine/refine_device.cuh +++ b/cpp/src/neighbors/refine/refine_device.cuh @@ -119,6 +119,7 @@ void refine_device( distances.data_handle(), grid_dim_x, raft::resource::get_cuda_stream(handle), + std::nullopt, std::nullopt); // postprocessing -- neighbors from position to actual id diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 48a444fa18..f01efcef64 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -139,6 +139,7 @@ ConfigureTest( GPUS 1 PERCENT 100 ) +target_link_libraries(NEIGHBORS_ANN_IVF_FLAT_UDF_TEST PRIVATE CUDA::nvrtc) ConfigureTest( NAME NEIGHBORS_ANN_IVF_PQ_TEST diff --git a/cpp/tests/neighbors/ann_ivf_flat/test_udf.cu b/cpp/tests/neighbors/ann_ivf_flat/test_udf.cu index ecd0c0fe0e..7e879baeb4 100644 --- a/cpp/tests/neighbors/ann_ivf_flat/test_udf.cu +++ b/cpp/tests/neighbors/ann_ivf_flat/test_udf.cu @@ -5,7 +5,9 @@ #include +#include #include +#include #include #include #include @@ -18,6 +20,8 @@ #include #include #include +#include +#include #include namespace cuvs::neighbors::ivf_flat { @@ -36,6 +40,65 @@ CUVS_METRIC(chebyshev_linf, { acc = (d > acc) ? 
d : acc; }) +namespace { + +void check_nvrtc(nvrtcResult result) +{ + RAFT_EXPECTS(result == NVRTC_SUCCESS, "nvrtc error: %s", nvrtcGetErrorString(result)); +} + +std::vector compile_ltoir_with_nvrtc(std::string const& source) +{ + int device = 0; + int major = 0; + int minor = 0; + RAFT_CUDA_TRY(cudaGetDevice(&device)); + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); + auto arch = std::string{"-arch=sm_"} + std::to_string(major * 10 + minor); + + nvrtcProgram prog; + check_nvrtc(nvrtcCreateProgram(&prog, source.c_str(), "ltoir_metric_udf_test.cu", 0, nullptr, nullptr)); + + std::array opts{arch.c_str(), "-dlto", "-rdc=true", "--std=c++20", "-default-device"}; + auto compile_result = nvrtcCompileProgram(prog, static_cast(opts.size()), opts.data()); + if (compile_result != NVRTC_SUCCESS) { + size_t log_size = 0; + check_nvrtc(nvrtcGetProgramLogSize(prog, &log_size)); + std::string log(log_size, '\0'); + check_nvrtc(nvrtcGetProgramLog(prog, log.data())); + nvrtcDestroyProgram(&prog); + RAFT_FAIL("nvrtc compile error log:\n%s", log.c_str()); + } + + size_t ltoir_size = 0; + check_nvrtc(nvrtcGetLTOIRSize(prog, <oir_size)); + std::vector ltoir(ltoir_size); + check_nvrtc(nvrtcGetLTOIR(prog, reinterpret_cast(ltoir.data()))); + check_nvrtc(nvrtcDestroyProgram(&prog)); + return ltoir; +} + +cuvs::jit::ltoir_udf make_ltoir_l2_metric_udf() +{ + auto constexpr symbol_name = "cuvs_test_ltoir_l2_update_f32"; + std::string source = R"( +extern "C" __device__ float cuvs_test_ltoir_l2_update_f32(float x, float y, float acc) +{ + float d = x - y; + return acc + d * d; +} +)"; + + auto key = std::string{"ivf_flat_ltoir_metric_udf_test:"} + symbol_name; + return cuvs::jit::ltoir_udf{.abi = "rapids.cuvs.ivf_flat.metric.v1", + .payload = compile_ltoir_with_nvrtc(source), + .symbol_name = symbol_name, + .cache_key = key}; +} + +} // namespace + // 
============================================================================ // Test data traits for different types // ============================================================================ @@ -388,6 +451,96 @@ TYPED_TEST(IvfFlatUdfTest, CustomL2MatchesBuiltIn) 1.0)); } + +TEST(IvfFlatLtoirUdf, L2ArtifactMatchesBuiltIn) +{ + using T = float; + using Traits = TestDataTraits; + + raft::resources handle; + auto stream = raft::resource::get_cuda_stream(handle); + + auto database = Traits::database(); + auto queries = Traits::queries(); + int64_t const num_db_vecs = Traits::num_db_vecs; + int64_t const num_queries = 2; + int64_t const dim = Traits::dim; + int64_t const k = 4; + uint32_t const n_lists = 2; + uint32_t const n_probes = 2; + + rmm::device_uvector d_database(num_db_vecs * dim, stream); + rmm::device_uvector d_queries(num_queries * dim, stream); + raft::copy(d_database.data(), database.data(), database.size(), stream); + raft::copy(d_queries.data(), queries.data(), queries.size(), stream); + + auto database_view = raft::make_device_matrix_view(d_database.data(), + num_db_vecs, + dim); + auto queries_view = raft::make_device_matrix_view(d_queries.data(), + num_queries, + dim); + + ivf_flat::index_params index_params; + index_params.n_lists = n_lists; + index_params.metric = cuvs::distance::DistanceType::L2Expanded; + auto idx = ivf_flat::build(handle, index_params, database_view); + + rmm::device_uvector d_indices_builtin(num_queries * k, stream); + rmm::device_uvector d_distances_builtin(num_queries * k, stream); + rmm::device_uvector d_indices_ltoir(num_queries * k, stream); + rmm::device_uvector d_distances_ltoir(num_queries * k, stream); + + auto indices_builtin_view = + raft::make_device_matrix_view(d_indices_builtin.data(), num_queries, k); + auto distances_builtin_view = + raft::make_device_matrix_view(d_distances_builtin.data(), num_queries, k); + auto indices_ltoir_view = + raft::make_device_matrix_view(d_indices_ltoir.data(), num_queries, k); 
+ auto distances_ltoir_view = + raft::make_device_matrix_view(d_distances_ltoir.data(), num_queries, k); + + ivf_flat::search_params builtin_params; + builtin_params.n_probes = n_probes; + ivf_flat::search(handle, + builtin_params, + idx, + queries_view, + indices_builtin_view, + distances_builtin_view); + + ivf_flat::search_params ltoir_params; + ltoir_params.n_probes = n_probes; + ltoir_params.metric_ltoir_udf = make_ltoir_l2_metric_udf(); + ivf_flat::search(handle, + ltoir_params, + idx, + queries_view, + indices_ltoir_view, + distances_ltoir_view); + + std::vector h_indices_builtin(num_queries * k); + std::vector h_distances_builtin(num_queries * k); + std::vector h_indices_ltoir(num_queries * k); + std::vector h_distances_ltoir(num_queries * k); + + raft::copy(h_indices_builtin.data(), d_indices_builtin.data(), h_indices_builtin.size(), stream); + raft::copy( + h_distances_builtin.data(), d_distances_builtin.data(), h_distances_builtin.size(), stream); + raft::copy(h_indices_ltoir.data(), d_indices_ltoir.data(), h_indices_ltoir.size(), stream); + raft::copy(h_distances_ltoir.data(), d_distances_ltoir.data(), h_distances_ltoir.size(), stream); + raft::resource::sync_stream(handle); + + ASSERT_TRUE(eval_neighbours(h_indices_builtin, + h_indices_ltoir, + h_distances_builtin, + h_distances_ltoir, + static_cast(num_queries), + static_cast(k), + 1e-5, + 1.0)); +} + /** * Build the index with native L2, search with a different metric (Chebyshev UDF), and compare to * exhaustive top-k from naive_knn (DistanceType::Linf). 
With n_probes == n_lists every cluster is @@ -468,6 +621,160 @@ TEST(IvfFlatUdfChebyshev, ChebyshevMatchesNaiveKnnWhenProbingAllLists) min_recall)); } + +TEST(DeviceUDFDescriptor, CopiesValidDescriptorMetadata) +{ + std::array payload{1, 2, 3, 4}; + std::array shape{8}; + std::array strides{4}; + auto capture = cuvsUDFCapture{.name = "weights", + .dtype = "float32", + .shape = shape.data(), + .strides = strides.data(), + .ndim = 1, + .device_id = 0, + .pointer = 0x1234, + .flags = CUVS_UDF_CAPTURE_READONLY}; + auto desc = cuvsDeviceUDF{.abi = "rapids.cuvs.ivf_flat.metric.v1", + .payload_kind = CUVS_DEVICE_UDF_PAYLOAD_LTOIR, + .payload = payload.data(), + .payload_size = payload.size(), + .symbol_name = "cuvs_test_symbol", + .captures = &capture, + .n_captures = 1, + .cache_key = "descriptor-test-key", + .flags = 0}; + + auto udf = cuvs::jit::make_device_udf(desc); + + EXPECT_EQ(udf.abi, desc.abi); + EXPECT_EQ(udf.payload_kind, cuvs::jit::device_udf_payload_kind::ltoir); + EXPECT_EQ(udf.payload, std::vector(payload.begin(), payload.end())); + ASSERT_EQ(udf.captures.size(), 1); + EXPECT_EQ(udf.captures[0].name, "weights"); + EXPECT_EQ(udf.captures[0].dtype, "float32"); + EXPECT_EQ(udf.captures[0].shape, std::vector({8})); + EXPECT_EQ(udf.captures[0].strides, std::vector({4})); + EXPECT_EQ(udf.captures[0].pointer, 0x1234); + EXPECT_TRUE(udf.captures[0].readonly); +} + +TEST(DeviceUDFDescriptor, RejectsMissingRequiredFields) +{ + std::array payload{1, 2, 3, 4}; + auto desc = cuvsDeviceUDF{.abi = "rapids.cuvs.ivf_flat.metric.v1", + .payload_kind = CUVS_DEVICE_UDF_PAYLOAD_LTOIR, + .payload = payload.data(), + .payload_size = payload.size(), + .symbol_name = "cuvs_test_symbol", + .captures = nullptr, + .n_captures = 0, + .cache_key = "descriptor-test-key", + .flags = 0}; + + auto invalid = desc; + invalid.abi = nullptr; + EXPECT_THROW(cuvs::jit::make_device_udf(invalid), raft::logic_error); + + invalid = desc; + invalid.payload = nullptr; + 
EXPECT_THROW(cuvs::jit::make_device_udf(invalid), raft::logic_error); + + invalid = desc; + invalid.payload_size = 0; + EXPECT_THROW(cuvs::jit::make_device_udf(invalid), raft::logic_error); + + invalid = desc; + invalid.symbol_name = nullptr; + EXPECT_THROW(cuvs::jit::make_device_udf(invalid), raft::logic_error); + + invalid = desc; + invalid.cache_key = nullptr; + EXPECT_THROW(cuvs::jit::make_device_udf(invalid), raft::logic_error); + + invalid = desc; + invalid.flags = 1; + EXPECT_THROW(cuvs::jit::make_device_udf(invalid), raft::logic_error); +} + +TEST(DeviceUDFDescriptor, RejectsMalformedCaptures) +{ + std::array payload{1, 2, 3, 4}; + std::array shape{8}; + auto capture = cuvsUDFCapture{.name = "weights", + .dtype = "float32", + .shape = shape.data(), + .strides = nullptr, + .ndim = 1, + .device_id = 0, + .pointer = 0x1234, + .flags = CUVS_UDF_CAPTURE_READONLY}; + auto desc = cuvsDeviceUDF{.abi = "rapids.cuvs.ivf_flat.metric.v1", + .payload_kind = CUVS_DEVICE_UDF_PAYLOAD_LTOIR, + .payload = payload.data(), + .payload_size = payload.size(), + .symbol_name = "cuvs_test_symbol", + .captures = &capture, + .n_captures = 1, + .cache_key = "descriptor-test-key", + .flags = 0}; + + auto invalid = desc; + invalid.captures = nullptr; + invalid.n_captures = 1; + EXPECT_THROW(cuvs::jit::make_device_udf(invalid), raft::logic_error); + + auto invalid_capture = capture; + invalid_capture.name = nullptr; + invalid = desc; + invalid.captures = &invalid_capture; + EXPECT_THROW(cuvs::jit::make_device_udf(invalid), raft::logic_error); + + invalid_capture = capture; + invalid_capture.dtype = nullptr; + invalid = desc; + invalid.captures = &invalid_capture; + EXPECT_THROW(cuvs::jit::make_device_udf(invalid), raft::logic_error); + + invalid_capture = capture; + invalid_capture.pointer = 0; + invalid = desc; + invalid.captures = &invalid_capture; + EXPECT_THROW(cuvs::jit::make_device_udf(invalid), raft::logic_error); + + invalid_capture = capture; + invalid_capture.shape = nullptr; + 
invalid = desc; + invalid.captures = &invalid_capture; + EXPECT_THROW(cuvs::jit::make_device_udf(invalid), raft::logic_error); + + invalid_capture = capture; + invalid_capture.flags = 2; + invalid = desc; + invalid.captures = &invalid_capture; + EXPECT_THROW(cuvs::jit::make_device_udf(invalid), raft::logic_error); +} + +TEST(DeviceUDFDescriptor, CopiesCudaSourcePayloadKind) +{ + std::string source = "extern \"C\" __device__ float f(float x, float y, float acc);"; + auto desc = cuvsDeviceUDF{.abi = "rapids.cuvs.ivf_flat.metric.v1", + .payload_kind = CUVS_DEVICE_UDF_PAYLOAD_CUDA_SOURCE, + .payload = source.data(), + .payload_size = source.size(), + .symbol_name = "f", + .captures = nullptr, + .n_captures = 0, + .cache_key = "descriptor-test-key", + .flags = 0}; + + auto udf = cuvs::jit::make_device_udf(desc); + + EXPECT_EQ(udf.payload_kind, cuvs::jit::device_udf_payload_kind::cuda_source); + EXPECT_EQ(std::string(reinterpret_cast(udf.payload.data()), udf.payload.size()), + source); +} + /** * Invalid UDF source must fail NVRTC compilation; search should surface that as an exception, * not return garbage neighbors. diff --git a/python/cuvs/cuvs/_lib/__init__.py b/python/cuvs/cuvs/_lib/__init__.py new file mode 100644 index 0000000000..8eca3cc68a --- /dev/null +++ b/python/cuvs/cuvs/_lib/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 diff --git a/python/cuvs/cuvs/_lib/device_udf.py b/python/cuvs/cuvs/_lib/device_udf.py new file mode 100644 index 0000000000..035afd9410 --- /dev/null +++ b/python/cuvs/cuvs/_lib/device_udf.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import dataclasses +import hashlib +import json +from dataclasses import dataclass, field +from typing import Any, Literal, Mapping + +PayloadKind = Literal["ltoir", "cuda_source"] + +_SUPPORTED_PAYLOAD_KINDS = frozenset({"ltoir", "cuda_source"}) +_CACHE_KEY_VERSION = "rapids.cuvs.device_udf.cache_key.v1" + + +@dataclass(frozen=True) +class UDFTarget: + sm: str + cuda_version: str + nvrtc_version: str | None + nvjitlink_version: str + numba_cuda_mlir_version: str | None + compile_options: tuple[str, ...] = () + + def cache_metadata(self) -> dict[str, Any]: + return dataclasses.asdict(self) + + +@dataclass(frozen=True) +class UDFCapture: + name: str + dtype: str + shape: tuple[int, ...] + strides: tuple[int, ...] | None + device_id: int + readonly: bool + pointer: int + owner: object = field(compare=False, hash=False, repr=False) + + def cache_metadata(self) -> dict[str, Any]: + # Deliberately exclude pointer and owner. Pointers vary run to run; owner + # only preserves lifetime and must not affect code identity. + return { + "name": self.name, + "dtype": self.dtype, + "shape": self.shape, + "strides": self.strides, + "device_id": self.device_id, + "readonly": self.readonly, + } + + def c_descriptor_metadata(self) -> dict[str, Any]: + return { + **self.cache_metadata(), + "pointer": self.pointer, + } + + +@dataclass(frozen=True) +class UDFArtifact: + abi: str + payload_kind: PayloadKind + payload: bytes | str + symbol_name: str + captures: tuple[UDFCapture, ...] 
+ target: UDFTarget + cache_key: str + + def payload_bytes(self) -> bytes: + return _payload_bytes(self.payload) + + def c_descriptor_metadata(self) -> dict[str, Any]: + return { + "abi": self.abi, + "payload_kind": self.payload_kind, + "payload_size": len(self.payload_bytes()), + "symbol_name": self.symbol_name, + "captures": [cap.c_descriptor_metadata() for cap in self.captures], + "cache_key": self.cache_key, + } + + +def make_udf_artifact( + *, + abi: str, + payload_kind: PayloadKind, + payload: bytes | str, + symbol_name: str, + captures: tuple[UDFCapture, ...] = (), + target: UDFTarget, + source_hash: str, + lowering_version: str, + algorithm_options: Mapping[str, Any] | None = None, +) -> UDFArtifact: + cache_key = build_cache_key( + abi=abi, + payload_kind=payload_kind, + payload=payload, + target=target, + captures=captures, + source_hash=source_hash, + lowering_version=lowering_version, + algorithm_options=algorithm_options, + ) + return UDFArtifact( + abi=abi, + payload_kind=payload_kind, + payload=payload, + symbol_name=symbol_name, + captures=captures, + target=target, + cache_key=cache_key, + ) + + +def build_cache_key( + *, + abi: str, + payload_kind: PayloadKind, + payload: bytes | str, + target: UDFTarget, + captures: tuple[UDFCapture, ...] 
= (), + source_hash: str, + lowering_version: str, + algorithm_options: Mapping[str, Any] | None = None, +) -> str: + _validate_payload_kind(payload_kind) + key_material = { + "version": _CACHE_KEY_VERSION, + "abi": abi, + "payload_kind": payload_kind, + "payload_hash": _sha256_hex(_payload_bytes(payload)), + "target": target.cache_metadata(), + "captures": [capture.cache_metadata() for capture in captures], + "source_hash": source_hash, + "lowering_version": lowering_version, + "algorithm_options": dict(algorithm_options or {}), + } + return f"{_CACHE_KEY_VERSION}:{_sha256_hex(_stable_json_bytes(key_material))}" + + +def source_hash(source: str | bytes) -> str: + return _sha256_hex(_payload_bytes(source)) + + +def _validate_payload_kind(payload_kind: str) -> None: + if payload_kind not in _SUPPORTED_PAYLOAD_KINDS: + kinds = ", ".join(sorted(_SUPPORTED_PAYLOAD_KINDS)) + raise ValueError(f"payload_kind must be one of: {kinds}") + + +def _payload_bytes(payload: bytes | str) -> bytes: + if isinstance(payload, bytes): + return payload + if isinstance(payload, str): + return payload.encode("utf-8") + raise TypeError("payload must be bytes or str") + + +def _stable_json_bytes(value: Any) -> bytes: + return json.dumps(value, sort_keys=True, separators=(",", ":")).encode("utf-8") + + +def _sha256_hex(data: bytes) -> str: + return hashlib.sha256(data).hexdigest() diff --git a/python/cuvs/cuvs/_lib/udf_backends/__init__.py b/python/cuvs/cuvs/_lib/udf_backends/__init__.py new file mode 100644 index 0000000000..8eca3cc68a --- /dev/null +++ b/python/cuvs/cuvs/_lib/udf_backends/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+# SPDX-License-Identifier: Apache-2.0 diff --git a/python/cuvs/cuvs/_lib/udf_backends/numba_cuda_mlir.py b/python/cuvs/cuvs/_lib/udf_backends/numba_cuda_mlir.py new file mode 100644 index 0000000000..7fb7153291 --- /dev/null +++ b/python/cuvs/cuvs/_lib/udf_backends/numba_cuda_mlir.py @@ -0,0 +1,135 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import inspect +import textwrap +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from typing import Any + +from cuvs._lib.device_udf import ( + UDFArtifact, + UDFCapture, + UDFTarget, + make_udf_artifact, + source_hash, +) +from cuvs._lib.udf_validation import UDFCompilationError + + +@dataclass(frozen=True) +class NumbaCudaMLIRCompileResult: + artifact: UDFArtifact + return_type: Any + + +class NumbaCudaMLIRBackend: + payload_kind = "ltoir" + + def __init__(self, cuda_module: Any | None = None, types_module: Any | None = None): + if cuda_module is None or types_module is None: + from numba_cuda_mlir import cuda + from numba_cuda_mlir.numba_cuda import types + + cuda_module = cuda if cuda_module is None else cuda_module + types_module = types if types_module is None else types_module + + self.cuda = cuda_module + self.types = types_module + + def compile( + self, + fn: Any, + *, + abi: str, + symbol_name: str, + arg_types: Sequence[Any], + return_type: Any, + target: UDFTarget, + captures: tuple[UDFCapture, ...] 
= (), + lowering_version: str, + algorithm_options: Mapping[str, Any] | None = None, + forceinline: bool = True, + source: str | bytes | None = None, + ) -> NumbaCudaMLIRCompileResult: + sig = return_type(*arg_types) + try: + payload, actual_return_type = self.cuda.compile( + fn, + sig=sig, + device=True, + abi="c", + abi_info={"abi_name": symbol_name}, + output="ltoir", + forceinline=forceinline, + ) + except Exception as exc: # pragma: no cover - exercised in GPU env tests + raise UDFCompilationError(str(exc)) from exc + + if actual_return_type != return_type: + raise TypeError( + f"UDF return type must be {return_type}, got {actual_return_type}" + ) + + artifact = make_udf_artifact( + abi=abi, + payload_kind="ltoir", + payload=payload, + symbol_name=symbol_name, + captures=captures, + target=target, + source_hash=source_hash( + source if source is not None else _function_source(fn) + ), + lowering_version=lowering_version, + algorithm_options=algorithm_options, + ) + return NumbaCudaMLIRCompileResult( + artifact=artifact, + return_type=actual_return_type, + ) + + def float32(self) -> Any: + return self.types.float32 + + def int64(self) -> Any: + return self.types.int64 + + def float32_pointer(self) -> Any: + return self.types.CPointer(self.types.float32) + + +def current_target(cuda_module: Any, *, compile_options: Sequence[str] = ()) -> UDFTarget: + dev = cuda_module.get_current_device() + cc = getattr(dev, "compute_capability") + sm = f"{cc.major}{cc.minor}" + + cuda_version = _version_or_unknown(cuda_module) + try: + import numba_cuda_mlir + + numba_cuda_mlir_version = getattr(numba_cuda_mlir, "__version__", None) + except Exception: # pragma: no cover + numba_cuda_mlir_version = None + + return UDFTarget( + sm=sm, + cuda_version=cuda_version, + nvrtc_version=None, + nvjitlink_version="unknown", + numba_cuda_mlir_version=numba_cuda_mlir_version, + compile_options=tuple(compile_options), + ) + + +def _function_source(fn: Any) -> str: + try: + return 
textwrap.dedent(inspect.getsource(fn)) + except (OSError, TypeError): + return getattr(fn, "__qualname__", repr(fn)) + + +def _version_or_unknown(module: Any) -> str: + return str(getattr(module, "__version__", "unknown")) diff --git a/python/cuvs/cuvs/_lib/udf_validation.py b/python/cuvs/cuvs/_lib/udf_validation.py new file mode 100644 index 0000000000..85e73e7030 --- /dev/null +++ b/python/cuvs/cuvs/_lib/udf_validation.py @@ -0,0 +1,297 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import ast +import inspect +import textwrap +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any + + +class UnsupportedUDFSyntax(TypeError): + pass + + +class UnsupportedUDFCapture(TypeError): + pass + + +class UDFCompilationError(RuntimeError): + pass + + +class UDFABIError(RuntimeError): + pass + + +ALLOWED_CALLS = frozenset( + { + "abs", + "min", + "max", + "squared_diff", + "abs_diff", + "dot_product", + "range", + } +) +_ALLOWED_NODE_TYPES = ( + ast.Module, + ast.FunctionDef, + ast.arguments, + ast.arg, + ast.Load, + ast.Store, + ast.Return, + ast.Assign, + ast.AnnAssign, + ast.Expr, + ast.Name, + ast.Constant, + ast.BinOp, + ast.UnaryOp, + ast.BoolOp, + ast.Compare, + ast.If, + ast.IfExp, + ast.For, + ast.Call, + ast.Subscript, + ast.Attribute, + ast.Tuple, + ast.List, + ast.Add, + ast.Sub, + ast.Mult, + ast.Div, + ast.FloorDiv, + ast.Mod, + ast.Pow, + ast.USub, + ast.UAdd, + ast.Not, + ast.And, + ast.Or, + ast.Eq, + ast.NotEq, + ast.Lt, + ast.LtE, + ast.Gt, + ast.GtE, +) + + +def validate_signature(fn: Any, expected: Iterable[str]) -> None: + expected_names = list(expected) + sig = inspect.signature(fn) + params = list(sig.parameters.values()) + names = [p.name for p in params] + + if names != expected_names: + raise TypeError(f"expected f({', '.join(expected_names)})") + + for param in params: + if param.kind not in ( + 
inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + raise TypeError("*args, **kwargs, and keyword-only args are not allowed") + + if param.default is not inspect.Parameter.empty: + raise TypeError("UDF parameters may not have defaults") + + +def validate_ivf_flat_metadata(order: str, initial: Any, coarse_metric: str) -> None: + if order not in {"min"}: + raise ValueError("v1 supports order='min'") + + if not isinstance(initial, (int, float)): + raise TypeError("initial must be a scalar number") + + if coarse_metric not in {"sqeuclidean", "inner_product"}: + raise ValueError("coarse_metric must be explicit and supported") + + +def validate_udf_policy(fn: Any) -> ast.FunctionDef: + source = textwrap.dedent(inspect.getsource(fn)) + tree = ast.parse(source) + function_defs = [node for node in tree.body if isinstance(node, ast.FunctionDef)] + if len(function_defs) != 1: + raise UnsupportedUDFSyntax("expected exactly one UDF function definition") + + for function_def in function_defs: + function_def.decorator_list = [] + + UDFPolicyValidator().visit(tree) + return function_defs[0] + + +class UDFPolicyValidator(ast.NodeVisitor): + def generic_visit(self, node: ast.AST) -> None: + if not isinstance(node, _ALLOWED_NODE_TYPES): + raise UnsupportedUDFSyntax(f"unsupported syntax: {type(node).__name__}") + super().generic_visit(node) + + def visit_Import(self, node: ast.Import) -> None: + raise UnsupportedUDFSyntax("imports are not allowed") + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + raise UnsupportedUDFSyntax("imports are not allowed") + + def visit_While(self, node: ast.While) -> None: + raise UnsupportedUDFSyntax("while loops are not allowed in v1") + + def visit_Lambda(self, node: ast.Lambda) -> None: + raise UnsupportedUDFSyntax("lambda expressions are not allowed") + + def visit_ListComp(self, node: ast.ListComp) -> None: + raise UnsupportedUDFSyntax("comprehensions are not allowed") + + def visit_SetComp(self, node: 
ast.SetComp) -> None: + raise UnsupportedUDFSyntax("comprehensions are not allowed") + + def visit_DictComp(self, node: ast.DictComp) -> None: + raise UnsupportedUDFSyntax("comprehensions are not allowed") + + def visit_GeneratorExp(self, node: ast.GeneratorExp) -> None: + raise UnsupportedUDFSyntax("comprehensions are not allowed") + + def visit_Call(self, node: ast.Call) -> None: + name = call_name(node) + if name not in ALLOWED_CALLS: + raise UnsupportedUDFSyntax(f"call to {name} is not allowed") + self.generic_visit(node) + + def visit_Assign(self, node: ast.Assign) -> None: + for target in node.targets: + if is_ctx_capture_write(target): + raise UnsupportedUDFSyntax("captures are read-only") + self.generic_visit(node) + + def visit_AugAssign(self, node: ast.AugAssign) -> None: + if is_ctx_capture_write(node.target): + raise UnsupportedUDFSyntax("captures are read-only") + self.generic_visit(node) + + def visit_AnnAssign(self, node: ast.AnnAssign) -> None: + if is_ctx_capture_write(node.target): + raise UnsupportedUDFSyntax("captures are read-only") + self.generic_visit(node) + + def visit_For(self, node: ast.For) -> None: + if not _is_range_call(node.iter): + raise UnsupportedUDFSyntax("for loops must iterate over range(...) 
in v1") + self.generic_visit(node) + + +def call_name(node: ast.Call) -> str: + if isinstance(node.func, ast.Name): + return node.func.id + if isinstance(node.func, ast.Attribute): + parts = [] + current: ast.AST = node.func + while isinstance(current, ast.Attribute): + parts.append(current.attr) + current = current.value + if isinstance(current, ast.Name): + parts.append(current.id) + return ".".join(reversed(parts)) + return type(node.func).__name__ + + +def is_ctx_capture_write(node: ast.AST) -> bool: + if isinstance(node, ast.Attribute): + return isinstance(node.value, ast.Name) and node.value.id == "ctx" + if isinstance(node, ast.Subscript): + return _contains_ctx_attribute(node.value) + if isinstance(node, (ast.Tuple, ast.List)): + return any(is_ctx_capture_write(elt) for elt in node.elts) + return False + + +def _contains_ctx_attribute(node: ast.AST) -> bool: + if isinstance(node, ast.Attribute): + return isinstance(node.value, ast.Name) and node.value.id == "ctx" + if isinstance(node, ast.Subscript): + return _contains_ctx_attribute(node.value) + return False + + +def _is_range_call(node: ast.AST) -> bool: + return isinstance(node, ast.Call) and call_name(node) == "range" + + +@dataclass(frozen=True) +class CaptureInfo: + name: str + dtype: str + shape: tuple[int, ...] + strides: tuple[int, ...] 
def _capture_device_id(cai: dict[str, Any]) -> int:
    """Extract a device ordinal from a CUDA Array Interface dictionary.

    If ``cai["stream"]`` is a non-empty tuple whose first element is an
    ``int``, that element is returned unchanged. Otherwise the ``"device"``
    entry is converted with ``int()``, defaulting to ``0`` when it is absent.
    """
    # NOTE(review): the CUDA Array Interface v3 spec describes ``stream`` as
    # None or an integer stream handle and defines no ``device`` key, so for
    # spec-conforming providers this presumably always yields 0. Confirm how
    # captures living on a non-zero device are meant to be detected.
    stream = cai.get("stream")
    has_int_head = (
        isinstance(stream, tuple)
        and bool(stream)
        and isinstance(stream[0], int)
    )
    if not has_int_head:
        return int(cai.get("device", 0))
    return stream[0]
b/python/cuvs/cuvs/neighbors/ivf_flat/__init__.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 +from ._udf import cuda_source_metric, metric from .ivf_flat import ( Index, IndexParams, @@ -18,8 +19,10 @@ "IndexParams", "SearchParams", "build", + "cuda_source_metric", "extend", "load", + "metric", "save", "search", ] diff --git a/python/cuvs/cuvs/neighbors/ivf_flat/_udf.py b/python/cuvs/cuvs/neighbors/ivf_flat/_udf.py new file mode 100644 index 0000000000..479d2d4042 --- /dev/null +++ b/python/cuvs/cuvs/neighbors/ivf_flat/_udf.py @@ -0,0 +1,396 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import ast +import inspect +import keyword +import re +import textwrap +from collections.abc import Mapping +from typing import Any, Callable + +from cuvs._lib.device_udf import ( + UDFArtifact, + UDFCapture, + UDFTarget, + make_udf_artifact, + source_hash, +) +from cuvs._lib.udf_validation import ( + CaptureInfo, + UnsupportedUDFCapture, + UDFCompilationError, + validate_capture, + validate_ivf_flat_metadata, + validate_signature, + validate_udf_policy, +) + +_IVF_FLAT_METRIC_ABI = "rapids.cuvs.ivf_flat.metric.v1" +_LOWERING_VERSION = "ivf-flat-ctx-explicit-lowered-v1" +_CUDA_SOURCE_LOWERING_VERSION = "ivf-flat-cuda-source-v1" +_CAPTURE_NAME_RE = re.compile(r"^[A-Za-z_][0-9A-Za-z_]*$") + + +def cuda_source_metric( + source: str | bytes, + *, + order: str = "min", + initial: float = 0.0, + coarse_metric: str = "sqeuclidean", + symbol_name: str = "cuvs_py_ivf_flat_cuda_source_metric", + target: UDFTarget | None = None, +) -> UDFArtifact: + """Package an expert CUDA/C++ IVF-Flat metric source string. + + This is an explicit advanced path for source compatible with the + existing C++ ``search_params.metric_udf`` contract. 
def metric(
    fn: Callable[..., Any] | None = None,
    *,
    order: str = "min",
    initial: float = 0.0,
    coarse_metric: str = "sqeuclidean",
    captures: Mapping[str, Any] | None = None,
    symbol_name: str | None = None,
    forceinline: bool = True,
) -> Callable[[Callable[..., Any]], UDFArtifact] | UDFArtifact:
    """Compile an IVF-Flat metric UDF to a device artifact.

    Usable both as a bare decorator (``@metric``) and as a decorator factory
    (``@metric(order="min", ...)``). The decorated function has the
    user-facing shape ``f(x, y, acc, ctx)`` and is lowered to the
    coordinate-wise IVF-Flat metric ABI.

    V1 supports either no captures or a single contiguous ``float32``
    CUDA-array capture. A capture is read inside the UDF through ``ctx``,
    e.g. ``ctx.weights[ctx.dim]`` for a capture named ``weights``.

    Parameters
    ----------
    fn
        The function to compile, or ``None`` when called with keyword
        options only.
    order, initial, coarse_metric, captures, symbol_name, forceinline
        Forwarded unchanged to ``_compile_metric``.

    Returns
    -------
    The compiled ``UDFArtifact`` when *fn* is given, otherwise a decorator
    that produces one.

    Raises
    ------
    TypeError
        If *fn* is given but is not callable.
    """
    compile_options = {
        "order": order,
        "initial": initial,
        "coarse_metric": coarse_metric,
        "captures": captures,
        "symbol_name": symbol_name,
        "forceinline": forceinline,
    }

    def decorate(user_fn: Callable[..., Any]) -> UDFArtifact:
        return _compile_metric(user_fn, **compile_options)

    if fn is not None:
        if callable(fn):
            return decorate(fn)
        raise TypeError("ivf_flat.metric must decorate a callable")
    # Called as ``metric(order=...)``: hand back the decorator itself.
    return decorate
lowering_version=_LOWERING_VERSION, + algorithm_options={ + "algorithm": "ivf_flat", + "coarse_metric": coarse_metric, + "capture_count": len(capture_infos), + "initial": initial, + "order": order, + }, + forceinline=forceinline, + source=lowered_source, + ) + return result.artifact + + +def _lower_metric( + fn: Callable[..., Any], + symbol_name: str, + source: str, + captures: tuple[CaptureInfo, ...], +) -> tuple[Callable[..., Any], str]: + tree = ast.parse(source) + function_defs = [ + node for node in tree.body if isinstance(node, ast.FunctionDef) + ] + if len(function_defs) != 1: + raise TypeError( + "expected exactly one IVF-Flat metric function definition" + ) + + func = function_defs[0] + func.decorator_list = [] + if not captures and _uses_name(func, "ctx"): + raise NotImplementedError( + "ivf_flat.metric ctx access requires an explicit capture" + ) + + if captures: + func = _CtxLoweringTransformer( + {capture.name for capture in captures} + ).visit(func) + if _uses_name(func, "ctx"): + raise NotImplementedError( + "ivf_flat.metric ctx may only be used as ctx.dim or " + "ctx." 
def _validate_captures(
    captures: Mapping[str, Any] | None,
) -> tuple[CaptureInfo, ...]:
    """Validate the ``captures`` mapping passed to ``ivf_flat.metric``.

    Parameters
    ----------
    captures
        Mapping of capture name to a CUDA-array-like object, or ``None``.

    Returns
    -------
    A tuple of ``CaptureInfo`` records, one per capture; empty when
    *captures* is ``None`` or empty.

    Raises
    ------
    NotImplementedError
        If more than one capture is supplied, or a capture is not
        ``float32``.
    TypeError
        If a capture name is not an ASCII identifier string.
    ValueError
        If a capture name is a Python keyword or one of the names reserved
        for the lowered signature (``x``, ``y``, ``acc``, ``ctx``, ``dim``).
    UnsupportedUDFCapture
        If a capture is not one-dimensional or not contiguous.
    """
    if not captures:
        return ()
    if len(captures) > 1:
        raise NotImplementedError(
            "ivf_flat.metric currently supports at most one capture"
        )

    capture_infos = []
    for name, value in captures.items():
        # ``fullmatch`` is required: with ``match`` the pattern's ``$`` also
        # matches just before a trailing newline, so ``"weights\n"`` would be
        # accepted and later injected into the lowered function's argument
        # list.
        if not isinstance(name, str) or not _CAPTURE_NAME_RE.fullmatch(name):
            raise TypeError("capture names must be valid Python identifiers")
        if keyword.iskeyword(name) or name in {"x", "y", "acc", "ctx", "dim"}:
            raise ValueError(f"capture name {name!r} is reserved")

        capture = validate_capture(name, value)
        if capture.dtype != "float32":
            raise NotImplementedError(
                "ivf_flat.metric currently supports only float32 captures"
            )
        if len(capture.shape) != 1:
            raise UnsupportedUDFCapture(
                "ivf_flat.metric captures must be one-dimensional"
            )
        # For a 1-D float32 array the only contiguous layouts are "strides
        # omitted" (None) or an explicit 4-byte stride.
        if capture.strides not in (None, (4,)):
            raise UnsupportedUDFCapture(
                "ivf_flat.metric captures must be contiguous"
            )
        capture_infos.append(capture)
    return tuple(capture_infos)
device must match cuVS resource device" + ) + + +def _to_udf_capture(capture: CaptureInfo) -> UDFCapture: + return UDFCapture( + name=capture.name, + dtype=capture.dtype, + shape=capture.shape, + strides=capture.strides, + device_id=capture.device_id, + readonly=capture.readonly, + pointer=capture.pointer, + owner=capture.owner, + ) + + +def _capture_arg_type(backend: Any, capture: CaptureInfo) -> Any: + if capture.dtype == "float32": + return backend.float32_pointer() + raise NotImplementedError( + f"unsupported ivf_flat.metric capture dtype {capture.dtype!r}" + ) + + +def _current_device_id(cuda: Any) -> int | None: + try: + device = cuda.get_current_device() + except Exception: + return None + + for attr in ("id", "device_id", "ordinal"): + value = getattr(device, attr, None) + if value is not None: + return int(value) + return 0 + + +def _function_source(fn: Callable[..., Any]) -> str: + return textwrap.dedent(inspect.getsource(fn)) + + +def _default_symbol_name(fn: Callable[..., Any], source: str) -> str: + stem = re.sub(r"[^0-9A-Za-z_]+", "_", fn.__qualname__).strip("_") + if not stem: + stem = "metric" + if stem[0].isdigit(): + stem = f"f_{stem}" + return f"cuvs_py_ivf_flat_metric_{stem}_{source_hash(source)[:16]}" + + +def _runtime_cuda_source_target() -> UDFTarget: + return UDFTarget( + sm="runtime", + cuda_version="runtime", + nvrtc_version="runtime", + nvjitlink_version="runtime", + numba_cuda_mlir_version=None, + compile_options=("nvrtc-lto",), + ) + + +def _payload_bytes(payload: str | bytes) -> bytes: + if isinstance(payload, bytes): + return payload + return payload.encode("utf-8") + + +def _uses_name(node: ast.AST, name: str) -> bool: + finder = _NameUseFinder(name) + finder.visit(node) + return finder.found + + +class _NameUseFinder(ast.NodeVisitor): + def __init__(self, name: str): + self.name = name + self.found = False + + def visit_Name(self, node: ast.Name) -> None: + if node.id == self.name: + self.found = True + + +class 
_CtxLoweringTransformer(ast.NodeTransformer): + def __init__(self, capture_names: set[str]): + self.capture_names = capture_names + + def visit_Attribute(self, node: ast.Attribute) -> ast.AST: + self.generic_visit(node) + if isinstance(node.value, ast.Name) and node.value.id == "ctx": + if node.attr == "dim": + return ast.copy_location( + ast.Name(id="dim", ctx=node.ctx), node + ) + if node.attr in self.capture_names: + return ast.copy_location( + ast.Name(id=node.attr, ctx=node.ctx), node + ) + raise UnsupportedUDFCapture( + f"ctx.{node.attr} has no matching capture" + ) + return node diff --git a/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pxd b/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pxd index b450151332..417462ac4a 100644 --- a/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pxd +++ b/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pxd @@ -4,7 +4,8 @@ # # cython: language_level=3 -from libc.stdint cimport int64_t, uint32_t, uintptr_t +from libc.stddef cimport size_t +from libc.stdint cimport int32_t, int64_t, uint32_t, uintptr_t from libcpp cimport bool from cuvs.common.c_api cimport cuvsError_t, cuvsResources_t @@ -13,6 +14,37 @@ from cuvs.distance_type cimport cuvsDistanceType from cuvs.neighbors.filters.filters cimport cuvsFilter + +cdef extern from "cuvs/core/device_udf.h" nogil: + + ctypedef enum cuvsDeviceUDFPayloadKind: + CUVS_DEVICE_UDF_PAYLOAD_LTOIR + CUVS_DEVICE_UDF_PAYLOAD_CUDA_SOURCE + + cdef enum: + CUVS_UDF_CAPTURE_READONLY + + ctypedef struct cuvsUDFCapture: + const char* name + const char* dtype + const int64_t* shape + const int64_t* strides + int32_t ndim + int32_t device_id + uintptr_t pointer + uint32_t flags + + ctypedef struct cuvsDeviceUDF: + const char* abi + cuvsDeviceUDFPayloadKind payload_kind + const void* payload + size_t payload_size + const char* symbol_name + const cuvsUDFCapture* captures + size_t n_captures + const char* cache_key + uint32_t flags + cdef extern from "cuvs/neighbors/ivf_flat.h" nogil: ctypedef struct 
cuvsIvfFlatIndexParams: @@ -29,6 +61,7 @@ cdef extern from "cuvs/neighbors/ivf_flat.h" nogil: ctypedef struct cuvsIvfFlatSearchParams: uint32_t n_probes + const cuvsDeviceUDF* metric_udf ctypedef cuvsIvfFlatSearchParams* cuvsIvfFlatSearchParams_t @@ -91,3 +124,14 @@ cdef class IndexParams: cdef class SearchParams: cdef cuvsIvfFlatSearchParams* params + cdef object _metric + cdef bytes _metric_payload + cdef bytes _metric_abi + cdef bytes _metric_symbol_name + cdef bytes _metric_cache_key + cdef bytes _metric_capture_0_name + cdef bytes _metric_capture_0_dtype + cdef cuvsDeviceUDF _metric_udf_desc + cdef cuvsUDFCapture _metric_captures[1] + cdef int64_t _metric_capture_0_shape[8] + cdef int64_t _metric_capture_0_strides[8] diff --git a/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx b/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx index 41788d1db0..321da33667 100644 --- a/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx +++ b/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx @@ -28,6 +28,7 @@ from cuvs.neighbors.filters import no_filter from libc.stdint cimport ( int8_t, + int32_t, int64_t, uint8_t, uint32_t, @@ -36,6 +37,7 @@ from libc.stdint cimport ( ) from cuvs.common.exceptions import check_cuvs +from cuvs._lib.device_udf import UDFArtifact cdef class IndexParams: @@ -270,17 +272,128 @@ cdef class SearchParams: ---------- n_probes: int The number of clusters to search. + metric: cuvs._lib.device_udf.UDFArtifact, optional + LTO-IR or CUDA source custom metric artifact for IVF-Flat. 
""" def __cinit__(self): cuvsIvfFlatSearchParamsCreate(&self.params) + self._metric = None + self._metric_payload = b"" + self._metric_abi = b"" + self._metric_symbol_name = b"" + self._metric_cache_key = b"" + self._metric_capture_0_name = b"" + self._metric_capture_0_dtype = b"" def __dealloc__(self): if self.params != NULL: check_cuvs(cuvsIvfFlatSearchParamsDestroy(self.params)) - def __init__(self, *, n_probes=20): + def __init__(self, *, n_probes=20, metric=None): self.params.n_probes = n_probes + self._clear_metric() + if metric is not None: + self._set_metric(metric) + + def _clear_metric(self): + self._metric = None + self._metric_payload = b"" + self._metric_abi = b"" + self._metric_symbol_name = b"" + self._metric_cache_key = b"" + self._metric_capture_0_name = b"" + self._metric_capture_0_dtype = b"" + self.params.metric_udf = NULL + self._metric_udf_desc.abi = NULL + self._metric_udf_desc.payload_kind = CUVS_DEVICE_UDF_PAYLOAD_LTOIR + self._metric_udf_desc.payload = NULL + self._metric_udf_desc.payload_size = 0 + self._metric_udf_desc.symbol_name = NULL + self._metric_udf_desc.captures = NULL + self._metric_udf_desc.n_captures = 0 + self._metric_udf_desc.cache_key = NULL + self._metric_udf_desc.flags = 0 + + def _set_metric(self, metric): + cdef Py_ssize_t i + cdef Py_ssize_t ndim + cdef const int64_t* strides_ptr + + if not isinstance(metric, UDFArtifact): + raise TypeError("metric must be a cuvs._lib.device_udf.UDFArtifact") + if metric.payload_kind not in ("ltoir", "cuda_source"): + raise TypeError( + "ivf_flat.SearchParams metric requires an LTO-IR or CUDA source UDF artifact" + ) + if metric.abi != "rapids.cuvs.ivf_flat.metric.v1": + raise ValueError("ivf_flat.SearchParams metric artifact has an incompatible ABI") + if metric.payload_kind == "cuda_source" and len(metric.captures) > 0: + raise NotImplementedError( + "ivf_flat.SearchParams CUDA source metric artifacts do not support captures" + ) + if len(metric.captures) > 1: + raise 
NotImplementedError( + "ivf_flat.SearchParams metric artifacts currently support at most one capture" + ) + + self._metric_payload = metric.payload_bytes() + self._metric_abi = metric.abi.encode("utf-8") + self._metric_symbol_name = metric.symbol_name.encode("utf-8") + self._metric_cache_key = metric.cache_key.encode("utf-8") + if not self._metric_payload: + raise ValueError("metric artifact payload must not be empty") + if not self._metric_symbol_name: + raise ValueError("metric artifact symbol_name must not be empty") + if not self._metric_cache_key: + raise ValueError("metric artifact cache_key must not be empty") + + self._metric = metric + self._metric_udf_desc.abi = self._metric_abi + if metric.payload_kind == "cuda_source": + self._metric_udf_desc.payload_kind = CUVS_DEVICE_UDF_PAYLOAD_CUDA_SOURCE + else: + self._metric_udf_desc.payload_kind = CUVS_DEVICE_UDF_PAYLOAD_LTOIR + self._metric_udf_desc.payload = self._metric_payload + self._metric_udf_desc.payload_size = len(self._metric_payload) + self._metric_udf_desc.symbol_name = self._metric_symbol_name + self._metric_udf_desc.captures = NULL + self._metric_udf_desc.n_captures = 0 + self._metric_udf_desc.cache_key = self._metric_cache_key + self._metric_udf_desc.flags = 0 + + if metric.captures: + capture = metric.captures[0] + ndim = len(capture.shape) + if ndim > 8: + raise NotImplementedError( + "ivf_flat.SearchParams metric captures currently support at most 8 dimensions" + ) + self._metric_capture_0_name = capture.name.encode("utf-8") + self._metric_capture_0_dtype = capture.dtype.encode("utf-8") + for i in range(ndim): + self._metric_capture_0_shape[i] = capture.shape[i] + + strides_ptr = NULL + if capture.strides is not None: + for i in range(ndim): + self._metric_capture_0_strides[i] = capture.strides[i] + strides_ptr = &self._metric_capture_0_strides[0] + + self._metric_captures[0].name = self._metric_capture_0_name + self._metric_captures[0].dtype = self._metric_capture_0_dtype + 
self._metric_captures[0].shape = &self._metric_capture_0_shape[0] + self._metric_captures[0].strides = strides_ptr + self._metric_captures[0].ndim = ndim + self._metric_captures[0].device_id = capture.device_id + self._metric_captures[0].pointer = capture.pointer + self._metric_captures[0].flags = 0 + if capture.readonly: + self._metric_captures[0].flags = CUVS_UDF_CAPTURE_READONLY + self._metric_udf_desc.captures = &self._metric_captures[0] + self._metric_udf_desc.n_captures = 1 + + self.params.metric_udf = &self._metric_udf_desc def get_handle(self): return self.params @@ -289,6 +402,10 @@ cdef class SearchParams: def n_probes(self): return self.params.n_probes + @property + def metric(self): + return self._metric + @auto_sync_resources @auto_convert_output diff --git a/python/cuvs/cuvs/tests/test_device_udf.py b/python/cuvs/cuvs/tests/test_device_udf.py new file mode 100644 index 0000000000..337427f580 --- /dev/null +++ b/python/cuvs/cuvs/tests/test_device_udf.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from cuvs._lib.device_udf import ( + UDFCapture, + UDFTarget, + build_cache_key, + make_udf_artifact, + source_hash, +) + + +def _target(**kwargs): + values = { + "sm": "120", + "cuda_version": "13.2", + "nvrtc_version": "13.2", + "nvjitlink_version": "13.2", + "numba_cuda_mlir_version": "0.3.0", + "compile_options": ("-lto",), + } + values.update(kwargs) + return UDFTarget(**values) + + +def _capture(**kwargs): + values = { + "name": "weights", + "dtype": "float32", + "shape": (128,), + "strides": None, + "device_id": 0, + "readonly": True, + "pointer": 1234, + "owner": object(), + } + values.update(kwargs) + return UDFCapture(**values) + + +def _key(**kwargs): + values = { + "abi": "rapids.cuvs.ivf_flat.metric.v1", + "payload_kind": "ltoir", + "payload": b"fake-ltoir", + "target": _target(), + "captures": (_capture(),), + "source_hash": source_hash("def f(x, y, acc): return acc + (x - y) * (x - y)"), + "lowering_version": "ctx-lowering-v1", + "algorithm_options": {"adapter": "ivf_flat", "dtype": "float32"}, + } + values.update(kwargs) + return build_cache_key(**values) + + +def _artifact(**kwargs): + values = { + "abi": "rapids.cuvs.ivf_flat.metric.v1", + "payload_kind": "ltoir", + "payload": b"fake-ltoir", + "symbol_name": "cuvs_l2_update_f32", + "captures": (), + "target": _target(), + "source_hash": source_hash("def f(x, y, acc): return acc"), + "lowering_version": "ctx-lowering-v1", + "algorithm_options": {"adapter": "ivf_flat", "dtype": "float32"}, + } + values.update(kwargs) + return make_udf_artifact(**values) + + +def test_cache_key_is_stable_for_same_metadata(): + assert _key() == _key() + + +def test_cache_key_excludes_capture_pointer_and_owner(): + left = _capture(pointer=1234, owner=object()) + right = _capture(pointer=987654, owner=object()) + + assert _key(captures=(left,)) == _key(captures=(right,)) + + +def test_cache_key_includes_capture_metadata(): + assert 
_key(captures=(_capture(dtype="float32"),)) != _key( + captures=(_capture(dtype="int32"),) + ) + assert _key(captures=(_capture(shape=(128,)),)) != _key( + captures=(_capture(shape=(256,)),) + ) + + +def test_cache_key_includes_payload_target_and_algorithm_options(): + base = _key() + assert base != _key(payload=b"different-ltoir") + assert base != _key(target=_target(sm="100")) + assert base != _key(algorithm_options={"adapter": "ivf_flat", "dtype": "int32"}) + + +def test_make_udf_artifact_sets_cache_key_and_descriptor_metadata(): + artifact = make_udf_artifact( + abi="rapids.cuvs.ivf_flat.metric.v1", + payload_kind="ltoir", + payload=b"fake-ltoir", + symbol_name="cuvs_l2_update_f32", + captures=(_capture(pointer=42),), + target=_target(), + source_hash=source_hash("def f(x, y, acc): return acc"), + lowering_version="ctx-lowering-v1", + algorithm_options={"adapter": "ivf_flat"}, + ) + + assert artifact.cache_key.startswith("rapids.cuvs.device_udf.cache_key.v1:") + assert artifact.payload_bytes() == b"fake-ltoir" + + metadata = artifact.c_descriptor_metadata() + assert metadata["payload_size"] == len(b"fake-ltoir") + assert metadata["captures"][0]["pointer"] == 42 + assert "owner" not in metadata["captures"][0] + + +def test_ivf_flat_cuda_source_metric_builds_cuda_source_artifact(): + from cuvs.neighbors import ivf_flat + + artifact = ivf_flat.cuda_source_metric( + 'extern "C" __device__ float f(float x, float y, float acc);', + symbol_name="cuvs_test_source_metric", + ) + + assert artifact.abi == "rapids.cuvs.ivf_flat.metric.v1" + assert artifact.payload_kind == "cuda_source" + assert artifact.symbol_name == "cuvs_test_source_metric" + assert artifact.payload_bytes().startswith(b'extern "C"') + assert artifact.target.sm == "runtime" + + +def test_ivf_flat_search_params_accepts_cuda_source_udf_artifact(): + from cuvs.neighbors import ivf_flat + + artifact = _artifact( + payload_kind="cuda_source", + payload='extern "C" __device__ float f(float x, float y, float 
acc);', + ) + + params = ivf_flat.SearchParams(metric=artifact) + + assert params.metric is artifact + + +def test_ivf_flat_search_params_rejects_cuda_source_captures(): + from cuvs.neighbors import ivf_flat + + artifact = _artifact( + payload_kind="cuda_source", + payload='extern "C" __device__ float f(float x, float y, float acc);', + captures=(_capture(name="weights", pointer=1234),), + ) + + with pytest.raises(NotImplementedError, match="do not support captures"): + ivf_flat.SearchParams(metric=artifact) + + +def test_ivf_flat_search_params_rejects_wrong_abi(): + from cuvs.neighbors import ivf_flat + + artifact = _artifact(abi="rapids.cuvs.ivf_pq.metric.v1") + + with pytest.raises(ValueError, match="incompatible ABI"): + ivf_flat.SearchParams(metric=artifact) + + +def test_ivf_flat_search_params_rejects_too_many_captures(): + from cuvs.neighbors import ivf_flat + + artifact = _artifact( + captures=( + _capture(name="weights", pointer=1234), + _capture(name="bias", pointer=5678), + ) + ) + + with pytest.raises(NotImplementedError, match="at most one capture"): + ivf_flat.SearchParams(metric=artifact) + + +def test_invalid_payload_kind_fails(): + with pytest.raises(ValueError, match="payload_kind"): + _key(payload_kind="ptx") diff --git a/python/cuvs/cuvs/tests/test_ivf_flat.py b/python/cuvs/cuvs/tests/test_ivf_flat.py index 9435a1289c..cc0ba420f6 100644 --- a/python/cuvs/cuvs/tests/test_ivf_flat.py +++ b/python/cuvs/cuvs/tests/test_ivf_flat.py @@ -18,6 +18,7 @@ ) + def run_ivf_flat_build_search_test( n_rows=10000, n_cols=10, @@ -142,3 +143,203 @@ def test_extend(dtype, serialize): @pytest.mark.parametrize("sparsity", [0.5, 0.7, 1.0]) def test_filtered_ivf_flat(sparsity): run_filtered_search_test(ivf_flat, sparsity) + + +def test_ivf_flat_numba_cuda_mlir_ltoir_udf_matches_builtin_l2(): + cp = pytest.importorskip("cupy") + pytest.importorskip("numba_cuda_mlir") + from numba_cuda_mlir import cuda + + if not cuda.is_available(): + pytest.skip("CUDA is not available 
to numba_cuda_mlir") + + with cp.cuda.Device(0): + + @ivf_flat.metric( + order="min", + initial=0.0, + coarse_metric="sqeuclidean", + symbol_name="cuvs_py_ivf_flat_l2_update_f32_test", + ) + def l2_update(x, y, acc, ctx): + d = x - y + return acc + d * d + + dataset = cp.asarray( + [ + [0.0, 0.0, 0.0], + [1.0, 0.5, -0.5], + [2.0, -1.0, 0.25], + [-1.5, 2.0, 1.0], + [3.0, 1.5, -2.0], + [-2.0, -1.0, 2.5], + ], + dtype=cp.float32, + ) + queries = cp.asarray( + [[0.2, 0.1, -0.1], [2.1, -0.7, 0.4]], + dtype=cp.float32, + ) + + index = ivf_flat.build( + ivf_flat.IndexParams(n_lists=1, metric="sqeuclidean"), + dataset, + ) + builtin_distances, builtin_neighbors = ivf_flat.search( + ivf_flat.SearchParams(n_probes=1), + index, + queries, + 3, + ) + udf_distances, udf_neighbors = ivf_flat.search( + ivf_flat.SearchParams(n_probes=1, metric=l2_update), + index, + queries, + 3, + ) + + cp.testing.assert_allclose( + cp.asarray(udf_distances), + cp.asarray(builtin_distances), + rtol=1e-5, + atol=1e-5, + ) + cp.testing.assert_array_equal( + cp.asarray(udf_neighbors), cp.asarray(builtin_neighbors) + ) + + +def test_ivf_flat_cuda_source_metric_matches_builtin_l2(): + cp = pytest.importorskip("cupy") + + with cp.cuda.Device(0): + source = r""" +namespace cuvs::neighbors::ivf_flat::detail { +template +__device__ __forceinline__ void compute_dist_udf_impl( + AccT& acc, AccT x, AccT y) +{ + auto d = x - y; + acc += d * d; +} +} +""" + source_metric = ivf_flat.cuda_source_metric( + source, + symbol_name="cuvs_py_ivf_flat_cuda_source_l2_test", + ) + + dataset = cp.asarray( + [ + [0.0, 0.0, 0.0], + [1.0, 0.5, -0.5], + [2.0, -1.0, 0.25], + [-1.5, 2.0, 1.0], + [3.0, 1.5, -2.0], + [-2.0, -1.0, 2.5], + ], + dtype=cp.float32, + ) + queries = cp.asarray( + [[0.2, 0.1, -0.1], [2.1, -0.7, 0.4]], + dtype=cp.float32, + ) + + index = ivf_flat.build( + ivf_flat.IndexParams(n_lists=1, metric="sqeuclidean"), + dataset, + ) + builtin_distances, builtin_neighbors = ivf_flat.search( + 
ivf_flat.SearchParams(n_probes=1), + index, + queries, + 3, + ) + source_distances, source_neighbors = ivf_flat.search( + ivf_flat.SearchParams(n_probes=1, metric=source_metric), + index, + queries, + 3, + ) + + cp.testing.assert_allclose( + cp.asarray(source_distances), + cp.asarray(builtin_distances), + rtol=1e-5, + atol=1e-5, + ) + cp.testing.assert_array_equal( + cp.asarray(source_neighbors), cp.asarray(builtin_neighbors) + ) + + +def test_ivf_flat_ltoir_weighted_l2_capture_matches_reference(): + cp = pytest.importorskip("cupy") + pytest.importorskip("numba_cuda_mlir") + from numba_cuda_mlir import cuda + + if not cuda.is_available(): + pytest.skip("CUDA is not available to numba_cuda_mlir") + + with cp.cuda.Device(0): + weights = cp.asarray([0.25, 1.5, 3.0, 0.75], dtype=cp.float32) + + @ivf_flat.metric( + order="min", + initial=0.0, + coarse_metric="sqeuclidean", + captures={"weights": weights}, + symbol_name="cuvs_py_ivf_flat_weighted_l2_update_f32_test", + ) + def weighted_l2_update(x, y, acc, ctx): + d = x - y + return acc + ctx.weights[ctx.dim] * d * d + + dataset = cp.asarray( + [ + [0.0, 0.0, 0.0, 0.0], + [1.0, 0.5, -0.5, 2.0], + [2.0, -1.0, 0.25, -0.5], + [-1.5, 2.0, 1.0, 0.75], + [3.0, 1.5, -2.0, -1.0], + [-2.0, -1.0, 2.5, 1.25], + ], + dtype=cp.float32, + ) + queries = cp.asarray( + [[0.2, 0.1, -0.1, 0.5], [2.1, -0.7, 0.4, -0.25]], + dtype=cp.float32, + ) + + index = ivf_flat.build( + ivf_flat.IndexParams(n_lists=1, metric="sqeuclidean"), + dataset, + ) + udf_distances, udf_neighbors = ivf_flat.search( + ivf_flat.SearchParams(n_probes=1, metric=weighted_l2_update), + index, + queries, + 3, + ) + + diff = queries[:, None, :] - dataset[None, :, :] + reference_distances = cp.sum( + weights[None, None, :] * diff * diff, axis=2 + ) + reference_neighbors = cp.argsort( + reference_distances, axis=1 + )[:, :3].astype(cp.int64) + reference_top_distances = cp.take_along_axis( + reference_distances, reference_neighbors, axis=1 + ) + + cp.testing.assert_allclose( 
+ cp.asarray(udf_distances), + reference_top_distances, + rtol=1e-5, + atol=1e-5, + ) + cp.testing.assert_array_equal( + cp.asarray(udf_neighbors), reference_neighbors + ) + diff --git a/python/cuvs/cuvs/tests/test_ivf_flat_metric.py b/python/cuvs/cuvs/tests/test_ivf_flat_metric.py new file mode 100644 index 0000000000..4cf29c8a2b --- /dev/null +++ b/python/cuvs/cuvs/tests/test_ivf_flat_metric.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from cuvs._lib.udf_validation import UnsupportedUDFCapture, validate_capture +from cuvs.neighbors import ivf_flat +from cuvs.neighbors.ivf_flat._udf import _function_source, _lower_metric + + +class FakeCudaArray: + def __init__( + self, *, typestr=" 1024 + assert artifact.cache_key.startswith("rapids.cuvs.device_udf.cache_key.v1:") + assert result.return_type == backend.float32() + + + +class FakeType: + def __init__(self, name): + self.name = name + + def __call__(self, *arg_types): + return (self, arg_types) + + def __repr__(self): + return self.name + + +class FakeTypes: + float32 = FakeType("float32") + int64 = FakeType("int64") + + +class FakeCuda: + def compile(self, *args, **kwargs): + return b"fake-ltoir", FakeTypes.int64 + + +def test_numba_cuda_mlir_backend_rejects_return_type_mismatch(): + backend = NumbaCudaMLIRBackend(cuda_module=FakeCuda(), types_module=FakeTypes()) + target = UDFTarget( + sm="120", + cuda_version="13.2", + nvrtc_version=None, + nvjitlink_version="13.2", + numba_cuda_mlir_version="0.3.0", + compile_options=("-lto",), + ) + + with pytest.raises(TypeError, match="return type"): + backend.compile( + l2_update, + abi="rapids.cuvs.ivf_flat.metric.v1", + symbol_name="cuvs_l2_update_return_mismatch", + arg_types=(FakeTypes.float32, FakeTypes.float32, FakeTypes.float32), + return_type=FakeTypes.float32, + target=target, + lowering_version="explicit-lowered-v0", + ) diff --git 
a/python/cuvs/cuvs/tests/test_udf_validation.py b/python/cuvs/cuvs/tests/test_udf_validation.py new file mode 100644 index 0000000000..56bcf7c9ea --- /dev/null +++ b/python/cuvs/cuvs/tests/test_udf_validation.py @@ -0,0 +1,164 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from cuvs._lib.udf_validation import ( + UnsupportedUDFCapture, + UnsupportedUDFSyntax, + validate_capture, + validate_ivf_flat_metadata, + validate_signature, + validate_udf_policy, +) + + +def valid_l2(x, y, acc, ctx): + d = x - y + return acc + d * d + + +def valid_ctx_read(x, y, acc, ctx): + d = x - y + return acc + ctx.weights[ctx.dim] * d * d + + +def invalid_call(x, y, acc, ctx): + return acc + ctx.np.sqrt(x - y) + + +def invalid_ctx_write(x, y, acc, ctx): + ctx.weights[ctx.dim] = 1.0 + return acc + + +def invalid_for_iter(x, y, acc, ctx): + for item in ctx.weights: + acc = acc + item + return acc + + +def valid_range_loop(x, y, acc, ctx): + for i in range(2): + acc = acc + i * (x - y) + return acc + + +def test_validate_signature_accepts_expected_positional_args(): + validate_signature(valid_l2, ["x", "y", "acc", "ctx"]) + + +def test_validate_signature_rejects_wrong_names_and_defaults(): + def wrong(a, b, c, d): + return a + b + c + d + + def default_arg(x, y, acc, ctx=None): + return acc + + with pytest.raises(TypeError, match="expected f"): + validate_signature(wrong, ["x", "y", "acc", "ctx"]) + + with pytest.raises(TypeError, match="defaults"): + validate_signature(default_arg, ["x", "y", "acc", "ctx"]) + + +def test_validate_signature_rejects_varargs_and_keyword_only(): + def varargs(x, y, acc, ctx, *rest): + return acc + + def keyword_only(x, y, acc, *, ctx): + return acc + + with pytest.raises(TypeError, match="keyword-only"): + validate_signature(varargs, ["x", "y", "acc", "ctx", "rest"]) + + with pytest.raises(TypeError, match="keyword-only"): + validate_signature(keyword_only, ["x", "y", "acc", 
"ctx"]) + + +def test_validate_ivf_flat_metadata(): + validate_ivf_flat_metadata("min", 0.0, "sqeuclidean") + + with pytest.raises(ValueError, match="order"): + validate_ivf_flat_metadata("max", 0.0, "sqeuclidean") + + with pytest.raises(TypeError, match="initial"): + validate_ivf_flat_metadata("min", object(), "sqeuclidean") + + with pytest.raises(ValueError, match="coarse_metric"): + validate_ivf_flat_metadata("min", 0.0, "cosine") + + +def test_validate_udf_policy_accepts_basic_subset(): + validate_udf_policy(valid_l2) + validate_udf_policy(valid_ctx_read) + validate_udf_policy(valid_range_loop) + + +def test_validate_udf_policy_rejects_arbitrary_call(): + with pytest.raises(UnsupportedUDFSyntax, match="ctx.np.sqrt"): + validate_udf_policy(invalid_call) + + +def test_validate_udf_policy_rejects_capture_mutation(): + with pytest.raises(UnsupportedUDFSyntax, match="read-only"): + validate_udf_policy(invalid_ctx_write) + + +def test_validate_udf_policy_rejects_non_range_for_loop(): + with pytest.raises(UnsupportedUDFSyntax, match="range"): + validate_udf_policy(invalid_for_iter) + + +def test_validate_capture_accepts_cupy_cuda_array_interface(): + cp = pytest.importorskip("cupy") + + with cp.cuda.Device(0): + owner = cp.arange(8, dtype=cp.float32) + capture = validate_capture("weights", owner, expected_device=0) + + assert capture.name == "weights" + assert capture.dtype == "float32" + assert capture.shape == (8,) + # CuPy omits strides from CUDA Array Interface for contiguous arrays. 
+ assert capture.strides is None + assert capture.pointer == owner.data.ptr + assert capture.owner is owner + assert capture.readonly is True + + +def test_validate_capture_accepts_cupy_strided_view(): + cp = pytest.importorskip("cupy") + + with cp.cuda.Device(0): + owner = cp.arange(16, dtype=cp.float32)[::2] + capture = validate_capture("weights", owner, expected_device=0) + + assert capture.shape == owner.shape + assert capture.strides == owner.strides + assert capture.pointer == owner.data.ptr + + +def test_validate_capture_rejects_missing_cai(): + with pytest.raises(UnsupportedUDFCapture, match="CUDA Array Interface"): + validate_capture("weights", object()) + + +def test_validate_capture_rejects_unsupported_cupy_dtype(): + cp = pytest.importorskip("cupy") + + with cp.cuda.Device(0): + owner = cp.arange(8, dtype=cp.float64) + + with pytest.raises(UnsupportedUDFCapture, match="unsupported capture dtype"): + validate_capture("weights", owner, expected_device=0) + + +def test_validate_capture_rejects_wrong_expected_device(): + cp = pytest.importorskip("cupy") + + with cp.cuda.Device(0): + owner = cp.arange(8, dtype=cp.float32) + + with pytest.raises(UnsupportedUDFCapture, match="device"): + validate_capture("weights", owner, expected_device=1) From 8f34ac0e792fa5bc1ed6816074af3970d9d72b18 Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Tue, 26 May 2026 23:04:19 -0500 Subject: [PATCH 2/2] ENH add simple example e2e --- .../experimental/ivf_flat_udf_e2e_demo.py | 230 ++++++++++++++++++ 1 file changed, 230 insertions(+) create mode 100644 examples/experimental/ivf_flat_udf_e2e_demo.py diff --git a/examples/experimental/ivf_flat_udf_e2e_demo.py b/examples/experimental/ivf_flat_udf_e2e_demo.py new file mode 100644 index 0000000000..70a58a4813 --- /dev/null +++ b/examples/experimental/ivf_flat_udf_e2e_demo.py @@ -0,0 +1,230 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+# SPDX-License-Identifier: Apache-2.0 + +"""End-to-end IVF Flat custom metric UDF demo. + +This PoC demonstrates two custom-metric paths: + +1. A Python metric compiled to LTO-IR by numba-cuda-mlir, with a CuPy capture. +2. An expert CUDA/C++ source-string metric using the existing JIT/LTO UDF path. +""" + +from __future__ import annotations + +import importlib.util +import sys + + +def _require_module(name): + if importlib.util.find_spec(name) is None: + raise RuntimeError(f"required module {name!r} is not available") + + +def _print_banner(title): + print() + print("=" * 88) + print(title) + print("=" * 88) + + +def _asnumpy(cp, value): + return cp.asnumpy(cp.asarray(value)) + + +def _format_array(cp, value): + return str(_asnumpy(cp, value)) + + +def _weighted_l2_reference(cp, queries, dataset, weights, k): + diff = queries[:, None, :] - dataset[None, :, :] + distances = cp.sum(weights[None, None, :] * diff * diff, axis=2) + neighbors = cp.argsort(distances, axis=1)[:, :k].astype(cp.int64) + top_distances = cp.take_along_axis(distances, neighbors, axis=1) + return top_distances, neighbors + + +def run_python_metric_demo(cp, ivf_flat): + _print_banner("Example 1: Python @ivf_flat.metric weighted L2 with a CuPy capture") + + weights = cp.asarray([0.25, 1.5, 3.0, 0.75], dtype=cp.float32) + + @ivf_flat.metric( + order="min", + initial=0.0, + coarse_metric="sqeuclidean", + captures={"weights": weights}, + symbol_name="cuvs_demo_weighted_l2_update_f32", + ) + def weighted_l2_update(x, y, acc, ctx): + d = x - y + return acc + ctx.weights[ctx.dim] * d * d + + dataset = cp.asarray( + [ + [0.0, 0.0, 0.0, 0.0], + [1.0, 0.5, -0.5, 2.0], + [2.0, -1.0, 0.25, -0.5], + [-1.5, 2.0, 1.0, 0.75], + [3.0, 1.5, -2.0, -1.0], + [-2.0, -1.0, 2.5, 1.25], + ], + dtype=cp.float32, + ) + queries = cp.asarray( + [[0.2, 0.1, -0.1, 0.5], [2.1, -0.7, 0.4, -0.25]], + dtype=cp.float32, + ) + + index = ivf_flat.build( + ivf_flat.IndexParams(n_lists=1, metric="sqeuclidean"), + dataset, + ) + 
udf_distances, udf_neighbors = ivf_flat.search( + ivf_flat.SearchParams(n_probes=1, metric=weighted_l2_update), + index, + queries, + 3, + ) + reference_distances, reference_neighbors = _weighted_l2_reference( + cp, queries, dataset, weights, 3 + ) + + max_error = cp.max(cp.abs(cp.asarray(udf_distances) - reference_distances)) + distances_match = cp.allclose( + cp.asarray(udf_distances), reference_distances, rtol=1e-5, atol=1e-5 + ) + neighbors_match = cp.array_equal( + cp.asarray(udf_neighbors), reference_neighbors + ) + passed = bool(distances_match and neighbors_match) + + print("weights:") + print(_format_array(cp, weights)) + print("queries:") + print(_format_array(cp, queries)) + print("dataset:") + print(_format_array(cp, dataset)) + print("UDF neighbors:") + print(_format_array(cp, udf_neighbors)) + print("reference neighbors:") + print(_format_array(cp, reference_neighbors)) + print("UDF distances:") + print(_format_array(cp, udf_distances)) + print("reference distances:") + print(_format_array(cp, reference_distances)) + print(f"max abs distance error: {float(max_error):.8f}") + print(f"neighbors match: {neighbors_match}") + print(f"RESULT: {'PASS' if passed else 'FAIL'}") + + return passed + + +def run_cuda_source_metric_demo(cp, ivf_flat): + _print_banner("Example 2: Expert CUDA/C++ source-string L2 metric") + + dataset = cp.asarray( + [ + [0.0, 0.0, 0.0], + [1.0, 0.5, -0.5], + [2.0, -1.0, 0.25], + [-1.5, 2.0, 1.0], + [3.0, 1.5, -2.0], + [-2.0, -1.0, 2.5], + ], + dtype=cp.float32, + ) + queries = cp.asarray( + [[0.2, 0.1, -0.1], [2.1, -0.7, 0.4]], + dtype=cp.float32, + ) + + source = r""" +namespace cuvs::neighbors::ivf_flat::detail { +template +__device__ __forceinline__ void compute_dist_udf_impl( + AccT& acc, AccT x, AccT y) +{ + auto d = x - y; + acc += d * d; +} +} +""" + source_metric = ivf_flat.cuda_source_metric( + source, + symbol_name="cuvs_demo_cuda_source_l2_metric", + ) + + index = ivf_flat.build( + ivf_flat.IndexParams(n_lists=1, 
metric="sqeuclidean"), + dataset, + ) + builtin_distances, builtin_neighbors = ivf_flat.search( + ivf_flat.SearchParams(n_probes=1), + index, + queries, + 3, + ) + source_distances, source_neighbors = ivf_flat.search( + ivf_flat.SearchParams(n_probes=1, metric=source_metric), + index, + queries, + 3, + ) + + max_error = cp.max(cp.abs(cp.asarray(source_distances) - builtin_distances)) + distances_match = cp.allclose( + cp.asarray(source_distances), + cp.asarray(builtin_distances), + rtol=1e-5, + atol=1e-5, + ) + neighbors_match = cp.array_equal( + cp.asarray(source_neighbors), cp.asarray(builtin_neighbors) + ) + passed = bool(distances_match and neighbors_match) + + print("queries:") + print(_format_array(cp, queries)) + print("dataset:") + print(_format_array(cp, dataset)) + print("CUDA source neighbors:") + print(_format_array(cp, source_neighbors)) + print("built-in L2 neighbors:") + print(_format_array(cp, builtin_neighbors)) + print("CUDA source distances:") + print(_format_array(cp, source_distances)) + print("built-in L2 distances:") + print(_format_array(cp, builtin_distances)) + print(f"max abs distance error: {float(max_error):.8f}") + print(f"neighbors match: {neighbors_match}") + print(f"RESULT: {'PASS' if passed else 'FAIL'}") + + return passed + + +def main(): + _require_module("cupy") + _require_module("numba_cuda_mlir") + + import cupy as cp + from cuvs.neighbors import ivf_flat + from numba_cuda_mlir import cuda + + if not cuda.is_available(): + raise RuntimeError("CUDA is not available to numba_cuda_mlir") + + cp.set_printoptions(precision=4, suppress=True) + + with cp.cuda.Device(0): + python_metric_ok = run_python_metric_demo(cp, ivf_flat) + cuda_source_ok = run_cuda_source_metric_demo(cp, ivf_flat) + + return 0 if python_metric_ok and cuda_source_ok else 1 + + +if __name__ == "__main__": + try: + sys.exit(main()) + except Exception as exc: + print(f"ERROR: {exc}", file=sys.stderr) + sys.exit(1)