Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions cpp/src/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#pragma once

#include "../../../core/nvtx.hpp"
#include "../../vpq_dataset.cuh"
#include "../../../preprocessing/quantize/vpq_build-ext.cuh"
#include "graph_core.cuh"

#include <raft/core/copy.cuh>
Expand Down Expand Up @@ -2279,8 +2279,7 @@ index<T, IdxT> build(
idx.update_dataset(
res,
// TODO: hardcoding codebook math to `half`, we can do runtime dispatching later
cuvs::neighbors::vpq_build<decltype(dataset), half, int64_t>(
res, *params.compression, dataset));
cuvs::preprocessing::quantize::pq::vpq_build(res, *params.compression, dataset));

return idx;
}
Expand Down
5 changes: 4 additions & 1 deletion cpp/src/neighbors/detail/cagra/factory.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -11,6 +11,9 @@
#include "search_plan.cuh"
#include "search_single_cta.cuh"

#include <raft/core/resource/custom_resource.hpp>
#include <raft/util/cache.hpp>

#include <cuvs/neighbors/common.hpp>

namespace cuvs::neighbors::cagra::detail {
Expand Down
176 changes: 82 additions & 94 deletions cpp/src/neighbors/detail/vpq_dataset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,41 @@ void process_and_fill_codes(
RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 16]", pq_bits);
}
}(pq_bits);
bool need_copy_to_device =
cuvs::spatial::knn::detail::utils::check_pointer_residency(dataset.data_handle()) ==
cuvs::spatial::knn::detail::utils::pointer_residency::host_only;
bool need_batching = n_rows > kReasonableMaxBatchSize;
auto launch_work = [&](auto& dataset_view, auto& labels_view, auto& codes_view) {
if (inline_vq_labels || (!vq_labels.empty() && !vq_centers.empty())) {
predict_vq<label_t>(res, dataset_view, vq_centers, labels_view);
}
dim3 blocks(
raft::div_rounding_up_safe<ix_t>(dataset_view.extent(0), kBlockSize / threads_per_vec), 1, 1);
kernel<<<blocks, threads, sharedMemorySize, stream>>>(codes_view,
dataset_view,
pq_centers,
vq_centers,
raft::make_const_mdspan(labels_view),
rows_in_shared_memory,
pq_bits,
inline_vq_labels);
RAFT_CUDA_TRY(cudaPeekAtLastError());
};
auto batch_labels = raft::make_device_vector<label_t, IdxT>(res, 0);
if (!need_batching && !need_copy_to_device) {
// No batching needed, launch the kernel directly
auto dataset_view = raft::make_device_matrix_view(dataset.data_handle(), n_rows, dim);
auto labels_view = raft::make_device_vector_view<label_t, IdxT>(nullptr, 0);
if (inline_vq_labels) {
batch_labels = raft::make_device_vector<label_t, IdxT>(res, dataset_view.extent(0));
labels_view = batch_labels.view();
} else if (!vq_labels.empty() && !vq_centers.empty()) {
labels_view = vq_labels;
}
launch_work(dataset_view, labels_view, codes);
return;
}

for (const auto& batch : cuvs::spatial::knn::detail::utils::batch_load_iterator(
dataset.data_handle(),
n_rows,
Expand All @@ -475,53 +510,20 @@ void process_and_fill_codes(
stream,
rmm::mr::get_current_device_resource())) {
auto batch_view = raft::make_device_matrix_view(batch.data(), ix_t(batch.size()), dim);
auto batch_labels = raft::make_device_vector<label_t, IdxT>(res, 0);
auto batch_labels_view = raft::make_device_vector_view<label_t, IdxT>(nullptr, 0);
if (inline_vq_labels) {
batch_labels = raft::make_device_vector<label_t, IdxT>(res, batch.size());
batch_labels_view = batch_labels.view();
predict_vq<label_t>(res, batch_view, vq_centers, batch_labels_view);
} else {
if (!vq_labels.empty() && !vq_centers.empty()) {
batch_labels_view = raft::make_device_vector_view<label_t, IdxT>(
vq_labels.data_handle() + batch.offset(), batch.size());
predict_vq<label_t>(res, batch_view, vq_centers, batch_labels_view);
}
} else if (!vq_labels.empty() && !vq_centers.empty()) {
batch_labels_view = raft::make_device_vector_view<label_t, IdxT>(
vq_labels.data_handle() + batch.offset(), batch.size());
}
dim3 blocks(raft::div_rounding_up_safe<ix_t>(n_rows, kBlockSize / threads_per_vec), 1, 1);
kernel<<<blocks, threads, sharedMemorySize, stream>>>(
raft::make_device_matrix_view<uint8_t, IdxT>(
codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen),
batch_view,
pq_centers,
vq_centers,
raft::make_const_mdspan(batch_labels_view),
rows_in_shared_memory,
pq_bits,
inline_vq_labels);
RAFT_CUDA_TRY(cudaPeekAtLastError());
auto batch_codes_view = raft::make_device_matrix_view<uint8_t, IdxT>(
codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen);
launch_work(batch_view, batch_labels_view, batch_codes_view);
}
}

template <typename NewMathT, typename OldMathT, typename IdxT>
auto vpq_convert_math_type(const raft::resources& res, vpq_dataset<OldMathT, IdxT>&& src)
-> vpq_dataset<NewMathT, IdxT>
{
auto vq_code_book = raft::make_device_mdarray<NewMathT>(res, src.vq_code_book.extents());
auto pq_code_book = raft::make_device_mdarray<NewMathT>(res, src.pq_code_book.extents());

raft::linalg::map(res,
vq_code_book.view(),
cuvs::spatial::knn::detail::utils::mapping<NewMathT>{},
raft::make_const_mdspan(src.vq_code_book.view()));
raft::linalg::map(res,
pq_code_book.view(),
cuvs::spatial::knn::detail::utils::mapping<NewMathT>{},
raft::make_const_mdspan(src.pq_code_book.view()));
return vpq_dataset<NewMathT, IdxT>{
std::move(vq_code_book), std::move(pq_code_book), std::move(src.data)};
}

// Helper for operations using vectorized loads of raft::TxN_t
template <typename MathT, int VectorSize>
struct vec_op : raft::TxN_t<MathT, VectorSize> {
Expand Down Expand Up @@ -858,14 +860,40 @@ void process_and_fill_codes_subspaces(
}
}(pq_bits);

ix_t max_batch_size = std::min<ix_t>(n_rows, kReasonableMaxBatchSize);
auto copy_stream = raft::resource::get_cuda_stream(res); // Using the main stream by default
bool enable_prefetch = false;
if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL)) {
if (raft::resource::get_stream_pool_size(res) >= 1) {
enable_prefetch = true;
copy_stream = raft::resource::get_stream_from_stream_pool(res);
ix_t max_batch_size = std::min<ix_t>(n_rows, kReasonableMaxBatchSize);
auto copy_stream = raft::resource::get_cuda_stream(res); // Using the main stream by default
bool enable_prefetch_stream = false;
bool has_cuda_stream_pool_resource =
res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) &&
raft::resource::get_stream_pool_size(res) >= 1;
bool need_copy_to_device =
cuvs::spatial::knn::detail::utils::check_pointer_residency(dataset.data_handle()) ==
cuvs::spatial::knn::detail::utils::pointer_residency::host_only;
bool need_batching = n_rows > kReasonableMaxBatchSize;
auto launch_work = [&](auto& dataset_view, auto& labels_view, auto& codes_view) {
if (!vq_labels.empty() && !vq_centers.empty()) {
predict_vq<label_t>(res, dataset_view, vq_centers, labels_view);
}
dim3 blocks(
raft::div_rounding_up_safe<ix_t>(dataset_view.extent(0), kBlockSize / threads_per_vec), 1, 1);
kernel<<<blocks, threads, shared_memory_size, stream>>>(codes_view,
dataset_view,
pq_centers,
vq_centers,
raft::make_const_mdspan(labels_view),
pq_bits,
shared_memory_size > 0);
RAFT_CUDA_TRY(cudaPeekAtLastError());
};
if (!need_batching && !need_copy_to_device) {
// No batching and no copy to device needed, launch the kernel directly
auto dataset_view = raft::make_device_matrix_view(dataset.data_handle(), n_rows, dim);
launch_work(dataset_view, vq_labels, codes);
return;
}
if (has_cuda_stream_pool_resource && need_copy_to_device) {
enable_prefetch_stream = true;
copy_stream = raft::resource::get_stream_from_stream_pool(res);
}
auto vec_batches = cuvs::spatial::knn::detail::utils::batch_load_iterator(
dataset.data_handle(),
Expand All @@ -874,62 +902,22 @@ void process_and_fill_codes_subspaces(
max_batch_size,
copy_stream,
raft::resource::get_workspace_resource(res),
enable_prefetch);
enable_prefetch_stream);
vec_batches.prefetch_next_batch();
for (const auto& batch : vec_batches) {
auto batch_view = raft::make_device_matrix_view(batch.data(), ix_t(batch.size()), dim);
auto batch_labels = raft::make_device_vector_view<label_t, IdxT>(nullptr, 0);
if (!vq_labels.empty() && !vq_centers.empty()) {
batch_labels = raft::make_device_vector_view<label_t, IdxT>(
vq_labels.data_handle() + batch.offset(), batch.size());
predict_vq<label_t>(res, batch_view, vq_centers, batch_labels);
}
dim3 blocks(raft::div_rounding_up_safe<ix_t>(batch.size(), kBlockSize / threads_per_vec), 1, 1);
kernel<<<blocks, threads, shared_memory_size, stream>>>(
raft::make_device_matrix_view<uint8_t, IdxT>(
codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen),
batch_view,
pq_centers,
vq_centers,
raft::make_const_mdspan(batch_labels),
pq_bits,
shared_memory_size > 0);
RAFT_CUDA_TRY(cudaPeekAtLastError());
vec_batches.prefetch_next_batch();
raft::resource::sync_stream(res);
auto batch_codes_view = raft::make_device_matrix_view<uint8_t, IdxT>(
codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen);
launch_work(batch_view, batch_labels, batch_codes_view);
if (enable_prefetch_stream) {
vec_batches.prefetch_next_batch();
raft::resource::sync_stream(res);
}
}
}

template <typename DatasetT, typename MathT, typename IdxT>
auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset)
-> vpq_dataset<MathT, IdxT>
{
using label_t = uint32_t;
// Use a heuristic to impute missing parameters.
auto ps = fill_missing_params_heuristics(params, dataset);

// Train codes
auto vq_code_book = train_vq<MathT>(res, ps, dataset);
auto pq_code_book =
train_pq<MathT>(res, ps, dataset, raft::make_const_mdspan(vq_code_book.view()));

// Encode dataset
const IdxT n_rows = dataset.extent(0);
const IdxT codes_rowlen = sizeof(label_t) * (1 + raft::div_rounding_up_safe<IdxT>(
ps.pq_dim * ps.pq_bits, 8 * sizeof(label_t)));

auto codes = raft::make_device_matrix<uint8_t, IdxT, raft::row_major>(res, n_rows, codes_rowlen);
process_and_fill_codes<MathT, IdxT>(res,
ps,
dataset,
raft::make_const_mdspan(pq_code_book.view()),
raft::make_const_mdspan(vq_code_book.view()),
raft::make_device_vector_view<label_t, IdxT>(nullptr, 0),
codes.view(),
true);

return vpq_dataset<MathT, IdxT>{
std::move(vq_code_book), std::move(pq_code_book), std::move(codes)};
}

} // namespace cuvs::neighbors::detail
Loading
Loading