Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions c/include/cuvs/core/all.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <cuvs/core/c_config.h>
#include <cuvs/core/export.h>
#include <cuvs/core/device_udf.h>
#include <cuvs/core/c_api.h>

#include <cuvs/cluster/kmeans.h>
Expand Down
74 changes: 74 additions & 0 deletions c/include/cuvs/core/device_udf.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once
#ifndef CUVS_CORE_DEVICE_UDF_H
#define CUVS_CORE_DEVICE_UDF_H

#include <stddef.h>
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

/** Device UDF payload representation. */
typedef enum cuvsDeviceUDFPayloadKind {
CUVS_DEVICE_UDF_PAYLOAD_LTOIR = 1,
CUVS_DEVICE_UDF_PAYLOAD_CUDA_SOURCE = 2,
} cuvsDeviceUDFPayloadKind;

/** Flag bits that may be OR-ed into cuvsUDFCapture::flags. */
enum {
  /** Marks the capture as read-only. */
  CUVS_UDF_CAPTURE_READONLY = 1u
};

/** Borrowed capture descriptor for a device UDF. */
typedef struct cuvsUDFCapture {
/** Capture name as used by the frontend, e.g. "weights". */
const char* name;
/** Logical dtype string, e.g. "float32". */
const char* dtype;
/** Optional shape array of length ndim. */
const int64_t* shape;
/** Optional strides array of length ndim, in bytes. Null means contiguous/default. */
const int64_t* strides;
/** Number of dimensions in shape/strides. */
int32_t ndim;
/** CUDA device ordinal for the capture allocation. */
int32_t device_id;
/** CUDA device pointer for the capture allocation. */
uintptr_t pointer;
/** Bitmask of CUVS_UDF_CAPTURE_* flags. */
uint32_t flags;
} cuvsUDFCapture;

/**
 * @brief Borrowed description of a device user-defined function (UDF).
 *
 * The struct owns none of the memory it points to; consumers copy the payload
 * and the capture descriptors.
 */
typedef struct cuvsDeviceUDF {
  /** Identifier of the ABI the UDF implements, e.g. "rapids.cuvs.ivf_flat.metric.v1". */
  const char* abi;
  /** How the payload bytes are encoded. */
  cuvsDeviceUDFPayloadKind payload_kind;
  /** Borrowed pointer to the payload bytes. */
  const void* payload;
  /** Size of the payload in bytes. */
  size_t payload_size;
  /**
   * Optional externally visible device symbol in payload.
   * Prefer an empty string over NULL when no symbol applies
   * (NOTE(review): confirm that every consumer accepts NULL here).
   */
  const char* symbol_name;
  /** Borrowed array of n_captures capture descriptors. */
  const cuvsUDFCapture* captures;
  /** Number of entries in captures. */
  size_t n_captures;
  /** Key used to identify this UDF artifact in a cache. */
  const char* cache_key;
  /** Reserved; must currently be zero. */
  uint32_t flags;
} cuvsDeviceUDF;

/** Handle type used by APIs that accept a borrowed, read-only UDF descriptor. */
typedef const cuvsDeviceUDF* cuvsDeviceUDF_t;

#ifdef __cplusplus
}
#endif

#endif // CUVS_CORE_DEVICE_UDF_H
6 changes: 6 additions & 0 deletions c/include/cuvs/neighbors/ivf_flat.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
#pragma once

#include <cuvs/core/c_api.h>
#include <cuvs/core/device_udf.h>
#include <cuvs/distance/distance.h>
#include <cuvs/neighbors/common.h>
#include <dlpack/dlpack.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#include <cuvs/core/export.h>
Expand Down Expand Up @@ -103,6 +105,8 @@ CUVS_EXPORT cuvsError_t cuvsIvfFlatIndexParamsDestroy(cuvsIvfFlatIndexParams_t i
struct cuvsIvfFlatSearchParams {
  /** How many clusters are searched. */
  uint32_t n_probes;
  /**
   * Optional custom-metric device UDF, held as a borrowed pointer.
   * NULL (the value set by cuvsIvfFlatSearchParamsCreate) means no UDF.
   */
  cuvsDeviceUDF_t metric_udf;
};

typedef struct cuvsIvfFlatSearchParams* cuvsIvfFlatSearchParams_t;
Expand Down Expand Up @@ -201,6 +205,7 @@ CUVS_EXPORT cuvsError_t cuvsIvfFlatIndexGetCenters(cuvsIvfFlatIndex_t index, DLM
*
* @code {.c}
* #include <cuvs/core/c_api.h>
* #include <cuvs/core/device_udf.h>
* #include <cuvs/neighbors/ivf_flat.h>
*
* // Create cuvsResources_t
Expand Down Expand Up @@ -257,6 +262,7 @@ CUVS_EXPORT cuvsError_t cuvsIvfFlatBuild(cuvsResources_t res,
*
* @code {.c}
* #include <cuvs/core/c_api.h>
* #include <cuvs/core/device_udf.h>
* #include <cuvs/neighbors/ivf_flat.h>
*
* // Create cuvsResources_t
Expand Down
30 changes: 28 additions & 2 deletions c/src/neighbors/ivf_flat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
*/

#include <cstdint>
#include <utility>

#include <dlpack/dlpack.h>

#include <raft/core/error.hpp>
Expand All @@ -21,6 +23,7 @@
#include "../core/interop.hpp"

#include <fstream>
#include <string>

namespace cuvs::neighbors::ivf_flat {
void convert_c_index_params(cuvsIvfFlatIndexParams params,
Expand All @@ -39,6 +42,28 @@ void convert_c_search_params(cuvsIvfFlatSearchParams params,
cuvs::neighbors::ivf_flat::search_params* out)
{
// Translates the C search-parameter struct into the C++ search_params object.
// n_probes is always copied; the UDF-related fields of `out` are only written
// when the caller attached a descriptor.
out->n_probes = params.n_probes;

// No custom metric requested: leave the UDF fields of `out` as they are.
if (params.metric_udf == nullptr) { return; }

// Validate the borrowed C descriptor and deep-copy it into an owning C++
// object (make_device_udf reports malformed descriptors via RAFT_EXPECTS).
auto metric_udf = cuvs::jit::make_device_udf(*params.metric_udf);
// Only the IVF-Flat metric ABI, version 1, is accepted here.
RAFT_EXPECTS(metric_udf.abi == "rapids.cuvs.ivf_flat.metric.v1",
"Unsupported IVF Flat metric UDF ABI: %s",
metric_udf.abi.c_str());

// CUDA source path: captures are rejected, and the payload bytes are stored
// as a string in out->metric_udf.
if (metric_udf.payload_kind == cuvs::jit::device_udf_payload_kind::cuda_source) {
RAFT_EXPECTS(metric_udf.captures.empty(),
"IVF Flat CUDA source metric UDF currently does not support captures");
out->metric_udf = std::string{reinterpret_cast<char const*>(metric_udf.payload.data()),
metric_udf.payload.size()};
return;
}

// LTO-IR path: zero or one capture is allowed, and the whole descriptor is
// moved into out->metric_ltoir_udf.
RAFT_EXPECTS(metric_udf.payload_kind == cuvs::jit::device_udf_payload_kind::ltoir,
"IVF Flat metric UDF currently requires an LTO-IR or CUDA source payload");
RAFT_EXPECTS(metric_udf.captures.size() <= 1,
"IVF Flat metric UDF currently supports at most one capture");

out->metric_ltoir_udf = std::move(metric_udf);
}
} // namespace cuvs::neighbors::ivf_flat

Expand Down Expand Up @@ -282,8 +307,9 @@ extern "C" cuvsError_t cuvsIvfFlatIndexParamsDestroy(cuvsIvfFlatIndexParams_t pa

// Allocates a cuvsIvfFlatSearchParams with its defaults (n_probes = 20, no
// metric UDF) and stores the pointer in *params. The work runs inside
// cuvs::core::translate_exceptions, whose result is returned to the caller.
extern "C" cuvsError_t cuvsIvfFlatSearchParamsCreate(cuvsIvfFlatSearchParams_t* params)
{
  return cuvs::core::translate_exceptions([=] {
    // Raise inside the translated region instead of dereferencing a null out-pointer.
    RAFT_EXPECTS(params != nullptr, "params must not be null");
    *params = new cuvsIvfFlatSearchParams{.n_probes = 20, .metric_udf = nullptr};
  });
}

extern "C" cuvsError_t cuvsIvfFlatSearchParamsDestroy(cuvsIvfFlatSearchParams_t params)
Expand Down
74 changes: 74 additions & 0 deletions cpp/include/cuvs/core/device_udf.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once
#ifndef CUVS_CORE_DEVICE_UDF_H
#define CUVS_CORE_DEVICE_UDF_H

#include <stddef.h>
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

/** Device UDF payload representation. */
typedef enum cuvsDeviceUDFPayloadKind {
CUVS_DEVICE_UDF_PAYLOAD_LTOIR = 1,
CUVS_DEVICE_UDF_PAYLOAD_CUDA_SOURCE = 2,
} cuvsDeviceUDFPayloadKind;

/** Flag bits that may be OR-ed into cuvsUDFCapture::flags. */
enum {
  /** Marks the capture as read-only. */
  CUVS_UDF_CAPTURE_READONLY = 1u
};

/** Borrowed capture descriptor for a device UDF. */
typedef struct cuvsUDFCapture {
/** Capture name as used by the frontend, e.g. "weights". */
const char* name;
/** Logical dtype string, e.g. "float32". */
const char* dtype;
/** Optional shape array of length ndim. */
const int64_t* shape;
/** Optional strides array of length ndim, in bytes. Null means contiguous/default. */
const int64_t* strides;
/** Number of dimensions in shape/strides. */
int32_t ndim;
/** CUDA device ordinal for the capture allocation. */
int32_t device_id;
/** CUDA device pointer for the capture allocation. */
uintptr_t pointer;
/** Bitmask of CUVS_UDF_CAPTURE_* flags. */
uint32_t flags;
} cuvsUDFCapture;

/**
 * @brief Borrowed description of a device user-defined function (UDF).
 *
 * The struct owns none of the memory it points to; consumers copy the payload
 * and the capture descriptors.
 */
typedef struct cuvsDeviceUDF {
  /** Identifier of the ABI the UDF implements, e.g. "rapids.cuvs.ivf_flat.metric.v1". */
  const char* abi;
  /** How the payload bytes are encoded. */
  cuvsDeviceUDFPayloadKind payload_kind;
  /** Borrowed pointer to the payload bytes. */
  const void* payload;
  /** Size of the payload in bytes. */
  size_t payload_size;
  /**
   * Optional externally visible device symbol in payload.
   * Prefer an empty string over NULL when no symbol applies
   * (NOTE(review): confirm that every consumer accepts NULL here).
   */
  const char* symbol_name;
  /** Borrowed array of n_captures capture descriptors. */
  const cuvsUDFCapture* captures;
  /** Number of entries in captures. */
  size_t n_captures;
  /** Key used to identify this UDF artifact in a cache. */
  const char* cache_key;
  /** Reserved; must currently be zero. */
  uint32_t flags;
} cuvsDeviceUDF;

/** Handle type used by APIs that accept a borrowed, read-only UDF descriptor. */
typedef const cuvsDeviceUDF* cuvsDeviceUDF_t;

#ifdef __cplusplus
}
#endif

#endif // CUVS_CORE_DEVICE_UDF_H
101 changes: 101 additions & 0 deletions cpp/include/cuvs/core/device_udf.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <cuvs/core/device_udf.h>

#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>
#include <utility>

#include <raft/core/error.hpp>

namespace cuvs::jit {

/**
 * @brief C++ counterpart of the C enum cuvsDeviceUDFPayloadKind.
 *
 * The underlying values (0, 1) differ from the C enumerators (1, 2), so the two
 * must be converted by enumerator rather than by a numeric cast.
 */
enum class device_udf_payload_kind {
  ltoir,       ///< payload is LTO-IR
  cuda_source  ///< payload is CUDA source text
};

/**
 * @brief Owning C++ description of one capture of a device UDF.
 *
 * Strings and arrays are held by value. Note that `readonly` defaults to true
 * when the struct is default-constructed.
 */
struct udf_capture {
  std::string name;              ///< capture name
  std::string dtype;             ///< logical dtype string
  std::vector<int64_t> shape;    ///< extents; empty when there are no dimensions
  std::vector<int64_t> strides;  ///< strides; empty when none were supplied
  int32_t device_id{0};          ///< CUDA device ordinal
  std::uintptr_t pointer{0};     ///< device pointer stored as an integer
  bool readonly{true};           ///< whether the capture is read-only
};

/**
 * @brief Owning C++ description of a device UDF.
 *
 * All members are stored by value. A default-constructed object has an empty
 * payload and reports the LTO-IR payload kind.
 */
struct device_udf {
  std::string abi;                                                   ///< ABI identifier
  device_udf_payload_kind payload_kind{device_udf_payload_kind::ltoir};  ///< payload encoding
  std::vector<uint8_t> payload;                                      ///< payload bytes
  std::string symbol_name;                                           ///< device symbol in payload
  std::string cache_key;                                             ///< key identifying the artifact
  std::vector<udf_capture> captures;                                 ///< capture descriptors
};

/** Alias of device_udf. */
using ltoir_udf = device_udf;

/**
 * @brief Map a C payload-kind enumerator onto its C++ counterpart.
 *
 * @param kind value taken from a C descriptor
 * @return the matching device_udf_payload_kind
 *
 * Any value other than the two known enumerators is reported through RAFT_FAIL.
 */
inline device_udf_payload_kind payload_kind_from_c(cuvsDeviceUDFPayloadKind kind)
{
  // Deliberately no `default:` label, so compilers can warn when an enumerator is unhandled.
  switch (kind) {
    case CUVS_DEVICE_UDF_PAYLOAD_CUDA_SOURCE: {
      return device_udf_payload_kind::cuda_source;
    }
    case CUVS_DEVICE_UDF_PAYLOAD_LTOIR: {
      return device_udf_payload_kind::ltoir;
    }
  }
  RAFT_FAIL("Unsupported cuVS device UDF payload kind: %d", static_cast<int>(kind));
  // Follows RAFT_FAIL so that every control path ends in a return statement.
  return device_udf_payload_kind::ltoir;
}

/**
 * @brief Validate a borrowed C descriptor and deep-copy it into an owning device_udf.
 *
 * Strings, payload bytes, and every capture (including its shape and strides
 * arrays) are copied, so the returned object does not reference `desc`.
 *
 * `symbol_name` is documented as optional in the C header, so a null pointer is
 * accepted and stored as an empty string.
 *
 * @param desc borrowed descriptor; only read during the call
 * @return owning copy of the descriptor
 *
 * Violations are reported through RAFT_EXPECTS / RAFT_FAIL:
 *  - null abi, payload, or cache_key; zero payload_size; non-zero flags
 *  - null captures with a non-zero n_captures
 *  - an unknown payload kind
 *  - per capture: null name or dtype, negative ndim, null shape with a
 *    non-zero ndim, zero pointer, or flag bits other than CUVS_UDF_CAPTURE_READONLY
 */
inline device_udf make_device_udf(cuvsDeviceUDF const& desc)
{
  RAFT_EXPECTS(desc.abi != nullptr, "device UDF abi must not be null");
  RAFT_EXPECTS(desc.payload != nullptr, "device UDF payload must not be null");
  RAFT_EXPECTS(desc.payload_size > 0, "device UDF payload_size must be non-zero");
  RAFT_EXPECTS(desc.cache_key != nullptr, "device UDF cache_key must not be null");
  RAFT_EXPECTS(desc.flags == 0, "device UDF flags must be zero");
  RAFT_EXPECTS(desc.n_captures == 0 || desc.captures != nullptr,
               "device UDF captures must not be null when n_captures is non-zero");

  auto const* payload_begin = static_cast<std::uint8_t const*>(desc.payload);
  auto out                  = device_udf{
    .abi          = std::string{desc.abi},
    .payload_kind = payload_kind_from_c(desc.payload_kind),
    // Parentheses select the iterator-range constructor explicitly.
    .payload = std::vector<std::uint8_t>(payload_begin, payload_begin + desc.payload_size),
    // Optional field: a missing symbol becomes an empty string.
    .symbol_name = desc.symbol_name != nullptr ? std::string{desc.symbol_name} : std::string{},
    .cache_key   = std::string{desc.cache_key}};

  out.captures.reserve(desc.n_captures);
  for (size_t i = 0; i < desc.n_captures; ++i) {
    auto const& capture = desc.captures[i];
    RAFT_EXPECTS(capture.name != nullptr, "device UDF capture name must not be null");
    RAFT_EXPECTS(capture.dtype != nullptr, "device UDF capture dtype must not be null");
    RAFT_EXPECTS(capture.ndim >= 0, "device UDF capture ndim must be non-negative");
    RAFT_EXPECTS(capture.ndim == 0 || capture.shape != nullptr,
                 "device UDF capture shape must not be null when ndim is non-zero");
    RAFT_EXPECTS(capture.pointer != 0, "device UDF capture pointer must not be zero");
    RAFT_EXPECTS((capture.flags & ~CUVS_UDF_CAPTURE_READONLY) == 0,
                 "device UDF capture has unsupported flags");

    auto next = udf_capture{.name      = std::string{capture.name},
                            .dtype     = std::string{capture.dtype},
                            .device_id = capture.device_id,
                            .pointer   = capture.pointer,
                            .readonly  = (capture.flags & CUVS_UDF_CAPTURE_READONLY) != 0};

    // ndim was checked to be non-negative above, so the conversion cannot wrap.
    auto const ndim = static_cast<size_t>(capture.ndim);
    if (ndim > 0) { next.shape.assign(capture.shape, capture.shape + ndim); }
    // Strides are optional; a null pointer leaves the vector empty.
    if (capture.strides != nullptr) {
      next.strides.assign(capture.strides, capture.strides + ndim);
    }
    out.captures.push_back(std::move(next));
  }

  return out;
}

} // namespace cuvs::jit
8 changes: 5 additions & 3 deletions cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,17 @@ struct StaticFatbinFragmentEntry final : FatbinFragmentEntry {
static const size_t length;
};

struct UDFFatbinFragment final : FatbinFragmentEntry {
struct UDFFatbinFragment final : FragmentEntry {
UDFFatbinFragment(std::string key, std::vector<uint8_t> bytes)
: key_(std::move(key)), bytes_(std::move(bytes))
{
}

const uint8_t* get_data() const override { return bytes_.data(); }
const uint8_t* get_data() const { return bytes_.data(); }

size_t get_length() const override { return bytes_.size(); }
size_t get_length() const { return bytes_.size(); }

bool add_to(nvJitLinkHandle& handle) const override;

const char* get_key() const override { return key_.c_str(); }

Expand Down
Loading
Loading