From cf7e572a532f91ea0ed628c7b892ab033136e912 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 26 Jan 2026 08:43:02 -0800 Subject: [PATCH 1/8] Separate PQ and CAGRA-Q Signed-off-by: Mickael Ide --- .../neighbors/detail/cagra/cagra_build.cuh | 5 +- cpp/src/neighbors/detail/cagra/factory.cuh | 5 +- cpp/src/neighbors/detail/vpq_dataset.cuh | 23 +------ cpp/src/neighbors/vpq_dataset.cuh | 41 ------------ cpp/src/preprocessing/quantize/detail/pq.cuh | 63 +++++++++++++++++++ cpp/src/preprocessing/quantize/pq.cu | 21 +++++++ .../preprocessing/quantize/vpq_build-ext.cuh | 28 +++++++++ cpp/src/preprocessing/quantize/vpq_build.cuh | 57 +++++++++++++++++ 8 files changed, 177 insertions(+), 66 deletions(-) delete mode 100644 cpp/src/neighbors/vpq_dataset.cuh create mode 100644 cpp/src/preprocessing/quantize/vpq_build-ext.cuh create mode 100644 cpp/src/preprocessing/quantize/vpq_build.cuh diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index 97d7bb1bac..ac923adb69 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -5,7 +5,7 @@ #pragma once #include "../../../core/nvtx.hpp" -#include "../../vpq_dataset.cuh" +#include "../../../preprocessing/quantize/vpq_build-ext.cuh" #include "graph_core.cuh" #include @@ -2281,8 +2281,7 @@ index build( idx.update_dataset( res, // TODO: hardcoding codebook math to `half`, we can do runtime dispatching later - cuvs::neighbors::vpq_build( - res, *params.compression, dataset)); + cuvs::preprocessing::quantize::pq::vpq_build(res, *params.compression, dataset)); return idx; } diff --git a/cpp/src/neighbors/detail/cagra/factory.cuh b/cpp/src/neighbors/detail/cagra/factory.cuh index a767d16530..75f236a65d 100644 --- a/cpp/src/neighbors/detail/cagra/factory.cuh +++ b/cpp/src/neighbors/detail/cagra/factory.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -11,6 +11,9 @@ #include "search_plan.cuh" #include "search_single_cta.cuh" +#include +#include + #include namespace cuvs::neighbors::cagra::detail { diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index e609100a76..aa9c775966 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -503,25 +503,6 @@ void process_and_fill_codes( } } -template -auto vpq_convert_math_type(const raft::resources& res, vpq_dataset&& src) - -> vpq_dataset -{ - auto vq_code_book = raft::make_device_mdarray(res, src.vq_code_book.extents()); - auto pq_code_book = raft::make_device_mdarray(res, src.pq_code_book.extents()); - - raft::linalg::map(res, - vq_code_book.view(), - cuvs::spatial::knn::detail::utils::mapping{}, - raft::make_const_mdspan(src.vq_code_book.view())); - raft::linalg::map(res, - pq_code_book.view(), - cuvs::spatial::knn::detail::utils::mapping{}, - raft::make_const_mdspan(src.pq_code_book.view())); - return vpq_dataset{ - std::move(vq_code_book), std::move(pq_code_book), std::move(src.data)}; -} - // Helper for operations using vectorized loads of raft::TxN_t template struct vec_op : raft::TxN_t { @@ -899,7 +880,7 @@ void process_and_fill_codes_subspaces( raft::resource::sync_stream(res); } } - +/* template auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) -> vpq_dataset @@ -930,6 +911,6 @@ auto vpq_build(const raft::resources& res, const vpq_params& params, const Datas return vpq_dataset{ std::move(vq_code_book), std::move(pq_code_book), std::move(codes)}; -} +}*/ } // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/vpq_dataset.cuh b/cpp/src/neighbors/vpq_dataset.cuh deleted file mode 100644 index 34b011bb0b..0000000000 --- a/cpp/src/neighbors/vpq_dataset.cuh +++ /dev/null @@ -1,41 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -#pragma once - -#include "detail/vpq_dataset.cuh" -#include - -#include - -namespace cuvs::neighbors { - -/** - * @brief Compress a dataset for use in CAGRA-Q search in place of the original data. - * - * @tparam DatasetT a row-major mdspan or mdarray (device or host). - * @tparam MathT a type of the codebook elements and internal math ops. - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] res - * @param[in] params VQ and PQ parameters for compressing the data - * @param[in] dataset a row-major mdspan or mdarray (device or host) [n_rows, dim]. - */ -template -auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) - -> vpq_dataset - -{ - if constexpr (std::is_same_v) { - return detail::vpq_convert_math_type( - res, detail::vpq_build(res, params, dataset)); - } else { - return detail::vpq_build(res, params, dataset); - } -} - -} // namespace cuvs::neighbors diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index db7cc8d5c1..ee40c36cbe 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -328,4 +328,67 @@ void inverse_transform( out, quant.params_quantizer.use_subspaces); } + +template +auto vpq_convert_math_type(const raft::resources& res, + cuvs::neighbors::vpq_dataset&& src) + -> cuvs::neighbors::vpq_dataset +{ + auto vq_code_book = raft::make_device_mdarray(res, src.vq_code_book.extents()); + auto pq_code_book = raft::make_device_mdarray(res, src.pq_code_book.extents()); + + raft::linalg::map(res, + vq_code_book.view(), + cuvs::spatial::knn::detail::utils::mapping{}, + raft::make_const_mdspan(src.vq_code_book.view())); + raft::linalg::map(res, + pq_code_book.view(), + cuvs::spatial::knn::detail::utils::mapping{}, + raft::make_const_mdspan(src.pq_code_book.view())); + return cuvs::neighbors::vpq_dataset{ + std::move(vq_code_book), std::move(pq_code_book), std::move(src.data)}; +} + +template +auto vpq_build(const raft::resources& res, + const cuvs::neighbors::vpq_params& params, + const DatasetT& dataset) -> cuvs::neighbors::vpq_dataset +{ + using label_t = uint32_t; + // Use a heuristic to impute missing parameters. + auto ps = cuvs::neighbors::detail::fill_missing_params_heuristics(params, dataset); + + // Train codes + auto vq_code_book = cuvs::neighbors::detail::train_vq(res, ps, dataset); + auto pq_code_book = cuvs::neighbors::detail::train_pq( + res, ps, dataset, raft::make_const_mdspan(vq_code_book.view())); + + // Encode dataset + const IdxT n_rows = dataset.extent(0); + const IdxT codes_rowlen = sizeof(label_t) * (1 + raft::div_rounding_up_safe( + ps.pq_dim * ps.pq_bits, 8 * sizeof(label_t))); + + auto codes = raft::make_device_matrix(res, n_rows, codes_rowlen); + cuvs::neighbors::detail::process_and_fill_codes( + res, + ps, + dataset, + raft::make_const_mdspan(pq_code_book.view()), + raft::make_const_mdspan(vq_code_book.view()), + raft::make_device_vector_view(nullptr, 0), + codes.view(), + true); + + return cuvs::neighbors::vpq_dataset{ + std::move(vq_code_book), std::move(pq_code_book), std::move(codes)}; +} + +template +auto vpq_build_half(const raft::resources& res, + const cuvs::neighbors::vpq_params& params, + const DatasetT& dataset) -> cuvs::neighbors::vpq_dataset +{ + return vpq_convert_math_type( + res, vpq_build(res, params, dataset)); +} } // namespace cuvs::preprocessing::quantize::pq::detail diff --git a/cpp/src/preprocessing/quantize/pq.cu b/cpp/src/preprocessing/quantize/pq.cu index 4a381c59ca..761474bdf8 100644 --- a/cpp/src/preprocessing/quantize/pq.cu +++ b/cpp/src/preprocessing/quantize/pq.cu @@ -52,4 +52,25 @@ CUVS_INST_QUANTIZATION(float, uint8_t); #undef CUVS_INST_QUANTIZATION +#define CUVS_INST_VPQ_BUILD(T) \ + auto vpq_build(const raft::resources& res, \ + const cuvs::neighbors::vpq_params& params, \ + const raft::host_matrix_view& dataset) \ + { \ + return detail::vpq_build_half(res, params, dataset); \ + } \ + auto vpq_build(const raft::resources& res, \ + const cuvs::neighbors::vpq_params& params, \ + const raft::device_matrix_view& dataset) \ + { \ + return detail::vpq_build_half(res, params, dataset); \ + } + +CUVS_INST_VPQ_BUILD(float); +CUVS_INST_VPQ_BUILD(half); +CUVS_INST_VPQ_BUILD(int8_t); +CUVS_INST_VPQ_BUILD(uint8_t); + +#undef CUVS_INST_VPQ_BUILD + } // namespace cuvs::preprocessing::quantize::pq diff --git a/cpp/src/preprocessing/quantize/vpq_build-ext.cuh b/cpp/src/preprocessing/quantize/vpq_build-ext.cuh new file mode 100644 index 0000000000..1745e53a33 --- /dev/null +++ b/cpp/src/preprocessing/quantize/vpq_build-ext.cuh @@ -0,0 +1,28 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include +#include + +namespace cuvs::preprocessing::quantize::pq { + +#define CUVS_INST_VPQ_BUILD(T) \ + cuvs::neighbors::vpq_dataset vpq_build( \ + const raft::resources& res, \ + const cuvs::neighbors::vpq_params& params, \ + const raft::host_matrix_view& dataset); \ + cuvs::neighbors::vpq_dataset vpq_build( \ + const raft::resources& res, \ + const cuvs::neighbors::vpq_params& params, \ + const raft::device_matrix_view& dataset); + +CUVS_INST_VPQ_BUILD(float); +CUVS_INST_VPQ_BUILD(half); +CUVS_INST_VPQ_BUILD(int8_t); +CUVS_INST_VPQ_BUILD(uint8_t); + +#undef CUVS_INST_VPQ_BUILD +} // namespace cuvs::preprocessing::quantize::pq diff --git a/cpp/src/preprocessing/quantize/vpq_build.cuh b/cpp/src/preprocessing/quantize/vpq_build.cuh new file mode 100644 index 0000000000..7296f44f44 --- /dev/null +++ b/cpp/src/preprocessing/quantize/vpq_build.cuh @@ -0,0 +1,57 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include +#include + +namespace cuvs::preprocessing::quantize::pq { + +#define CUVS_INST_VPQ_BUILD(T) \ + cuvs::neighbors::vpq_dataset vpq_build( \ + const raft::resources& res, \ + const cuvs::neighbors::vpq_params& params, \ + const raft::host_matrix_view& dataset); \ + cuvs::neighbors::vpq_dataset vpq_build( \ + const raft::resources& res, \ + const cuvs::neighbors::vpq_params& params, \ + const raft::device_matrix_view& dataset); + +CUVS_INST_VPQ_BUILD(float); +CUVS_INST_VPQ_BUILD(half); +CUVS_INST_VPQ_BUILD(int8_t); +CUVS_INST_VPQ_BUILD(uint8_t); + +#undef CUVS_INST_VPQ_BUILD + +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include +#include + +namespace cuvs::preprocessing::quantize::pq { + +#define CUVS_INST_VPQ_BUILD(T) \ + cuvs::neighbors::vpq_dataset vpq_build( \ + const raft::resources& res, \ + const cuvs::neighbors::vpq_params& params, \ + const raft::host_matrix_view& dataset); \ + cuvs::neighbors::vpq_dataset vpq_build( \ + const raft::resources& res, \ + const cuvs::neighbors::vpq_params& params, \ + const raft::device_matrix_view& dataset); + +CUVS_INST_VPQ_BUILD(float); +CUVS_INST_VPQ_BUILD(half); +CUVS_INST_VPQ_BUILD(int8_t); +CUVS_INST_VPQ_BUILD(uint8_t); + +#undef CUVS_INST_VPQ_BUILD + +} // namespace cuvs::preprocessing::quantize::pq From e7a98054a5076a2ad0cd14c6c98190d7a561a33a Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 27 Jan 2026 13:23:19 -0800 Subject: [PATCH 2/8] Use PQ API in SCANN code + Add simple cpp test Signed-off-by: Mickael Ide --- .../neighbors/scann/detail/scann_build.cuh | 131 ++++++------- .../neighbors/scann/detail/scann_quantize.cuh | 182 +----------------- cpp/tests/neighbors/ann_scann.cuh | 44 ++++- 3 files changed, 108 insertions(+), 249 deletions(-) diff --git a/cpp/src/neighbors/scann/detail/scann_build.cuh b/cpp/src/neighbors/scann/detail/scann_build.cuh index 8902fc2051..1741154583 100644 --- a/cpp/src/neighbors/scann/detail/scann_build.cuh +++ b/cpp/src/neighbors/scann/detail/scann_build.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -8,6 +8,7 @@ #include "../../detail/ann_utils.cuh" #include #include +#include #include #include @@ -157,45 +158,17 @@ index build( int dim_per_subspace = params.pq_dim; int num_clusters = 1 << params.pq_bits; - auto full_codebook = - raft::make_device_matrix(res, num_clusters * num_subspaces, dim_per_subspace); - - // Loop each subspace, training codebooks for each - for (int subspace = 0; subspace < num_subspaces; subspace++) { - int sub_dim_start = subspace * dim_per_subspace; - int sub_dim_end = (subspace + 1) * dim_per_subspace; - - auto sub_trainset = raft::make_device_matrix( - res, trainset_residuals.extent(0), (int64_t)dim_per_subspace); - raft::matrix::slice_coordinates avq_sub_coords( - 0, sub_dim_start, trainset_residuals.extent(0), sub_dim_end); - raft::matrix::slice( - res, raft::make_const_mdspan(trainset_residuals.view()), sub_trainset.view(), avq_sub_coords); - - // Set up quantization bits and params - cuvs::neighbors::vpq_params pq_params; - pq_params.pq_bits = params.pq_bits; - // For VPQ, pq_dim is the number of subspaces, not the dimension of the subspaces - pq_params.pq_dim = 1; - // We handle sampling/training set construction above, so use the full set in VPQ - pq_params.pq_kmeans_trainset_fraction = 1.0; - pq_params.kmeans_n_iters = params.pq_train_iters; - - // Create pq codebook for this subspace - auto sub_pq_codebook = - create_pq_codebook(res, raft::make_const_mdspan(sub_trainset.view()), pq_params); - - raft::copy(full_codebook.data_handle() + (subspace * sub_pq_codebook.size()), - sub_pq_codebook.data_handle(), - sub_pq_codebook.size(), - stream); - } - raft::resource::sync_stream(res); + cuvs::preprocessing::quantize::pq::params pq_build_params; + pq_build_params.pq_bits = params.pq_bits; + pq_build_params.pq_dim = num_subspaces; + pq_build_params.use_subspaces = true; + pq_build_params.use_vq = false; // We already computed residuals + pq_build_params.kmeans_n_iters = params.pq_train_iters; + pq_build_params.max_train_points_per_pq_code = pq_n_rows_train / num_clusters; + pq_build_params.pq_kmeans_type = cuvs::cluster::kmeans::kmeans_type::KMeansBalanced; - // Set up quantization bits and params - cuvs::neighbors::vpq_params pq_params; - pq_params.pq_bits = params.pq_bits; - pq_params.pq_dim = dataset.extent(1) / params.pq_dim; + auto pq_quantizer = cuvs::preprocessing::quantize::pq::build( + res, pq_build_params, raft::make_const_mdspan(trainset_residuals.view())); dataset_vec_batches.reset(); dataset_vec_batches.prefetch_next_batch(); @@ -226,9 +199,11 @@ index build( batch_soar_labels_view, params.soar_lambda); - // Compute and quantize residuals - auto avq_quant = quantize_residuals( - res, raft::make_const_mdspan(avq_residuals.view()), full_codebook.view(), pq_params); + // Compute and quantize residuals using the public PQ API + int64_t codes_dim = cuvs::preprocessing::quantize::pq::get_quantized_dim(pq_build_params); + auto avq_quant = raft::make_device_matrix(res, batch.size(), codes_dim); + cuvs::preprocessing::quantize::pq::transform( + res, pq_quantizer, raft::make_const_mdspan(avq_residuals.view()), avq_quant.view()); // Compute and quantize SOAR residuals auto soar_residuals = @@ -237,25 +212,49 @@ index build( raft::make_const_mdspan(centroids_view), raft::make_const_mdspan(batch_soar_labels_view)); - auto soar_quant = quantize_residuals( - res, raft::make_const_mdspan(soar_residuals.view()), full_codebook.view(), pq_params); + auto soar_quant = raft::make_device_matrix(res, batch.size(), codes_dim); + cuvs::preprocessing::quantize::pq::transform( + res, pq_quantizer, raft::make_const_mdspan(soar_residuals.view()), soar_quant.view()); // unpack codes - auto quantized_residuals = - raft::make_device_matrix(res, batch.size(), num_subspaces); - auto quantized_soar_residuals = - raft::make_device_matrix(res, batch.size(), num_subspaces); - - unpack_codes(res, - quantized_residuals.view(), - raft::make_const_mdspan(avq_quant.view()), - params.pq_bits, - num_subspaces); - unpack_codes(res, - quantized_soar_residuals.view(), - raft::make_const_mdspan(soar_quant.view()), - params.pq_bits, - num_subspaces); + if (pq_quantizer.params_quantizer.pq_bits == 8) { + // Copy unpacked codes to host + // TODO (rmaschal): these copies are blocking and not overlapped + raft::copy(idx.quantized_residuals().data_handle() + batch.offset() * num_subspaces, + avq_quant.data_handle(), + avq_quant.size(), + stream); + + raft::copy(idx.quantized_soar_residuals().data_handle() + batch.offset() * num_subspaces, + soar_quant.data_handle(), + soar_quant.size(), + stream); + } else { + auto quantized_residuals = + raft::make_device_matrix(res, batch.size(), num_subspaces); + auto quantized_soar_residuals = + raft::make_device_matrix(res, batch.size(), num_subspaces); + + unpack_codes(res, + quantized_residuals.view(), + raft::make_const_mdspan(avq_quant.view()), + params.pq_bits, + num_subspaces); + unpack_codes(res, + quantized_soar_residuals.view(), + raft::make_const_mdspan(soar_quant.view()), + params.pq_bits, + num_subspaces); + raft::copy(idx.quantized_residuals().data_handle() + batch.offset() * num_subspaces, + quantized_residuals.data_handle(), + quantized_residuals.size(), + stream); + + raft::copy(idx.quantized_soar_residuals().data_handle() + batch.offset() * num_subspaces, + quantized_soar_residuals.data_handle(), + quantized_soar_residuals.size(), + stream); + } // quantize dataset to bfloat16, if enabled. Similar to SOAR, quantization // is performed in this loop to improve locality @@ -270,18 +269,6 @@ index build( // Prefetch next batch dataset_vec_batches.prefetch_next_batch(); - // Copy unpacked codes to host - // TODO (rmaschal): these copies are blocking and not overlapped - raft::copy(idx.quantized_residuals().data_handle() + batch.offset() * num_subspaces, - quantized_residuals.data_handle(), - quantized_residuals.size(), - stream); - - raft::copy(idx.quantized_soar_residuals().data_handle() + batch.offset() * num_subspaces, - quantized_soar_residuals.data_handle(), - quantized_soar_residuals.size(), - stream); - if (params.reordering_bf16) { raft::copy(idx.bf16_dataset().data_handle() + batch.offset() * dim, bf16_dataset.data_handle(), @@ -296,7 +283,7 @@ index build( // Codebooks from VPQ have the shape [subspace idx, subspace dim, code] // This converts the codebook into matrix format for easy interoperability // with open-source ScaNN search - auto full_codebook_view = full_codebook.view(); + auto full_codebook_view = pq_quantizer.vpq_codebooks.pq_code_book.view(); raft::linalg::map_offset( res, diff --git a/cpp/src/neighbors/scann/detail/scann_quantize.cuh b/cpp/src/neighbors/scann/detail/scann_quantize.cuh index d00abad6d9..79f6e1bb05 100644 --- a/cpp/src/neighbors/scann/detail/scann_quantize.cuh +++ b/cpp/src/neighbors/scann/detail/scann_quantize.cuh @@ -3,8 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include "../../detail/vpq_dataset.cuh" -#include "../../ivf_pq/ivf_pq_codepacking.cuh" #include #include #include @@ -19,125 +17,6 @@ namespace cuvs::neighbors::experimental::scann::detail { /** Fix the internal indexing type to avoid integer underflows/overflows */ using ix_t = int64_t; -template -__launch_bounds__(BlockSize) RAFT_KERNEL process_and_fill_codes_subspaces_kernel( - raft::device_matrix_view out_codes, - raft::device_matrix_view dataset, - raft::device_matrix_view vq_centers, - raft::device_vector_view vq_labels, - raft::device_matrix_view pq_centers) -{ - constexpr uint32_t kSubWarpSize = std::min(raft::WarpSize, 1u << PqBits); - using subwarp_align = raft::Pow2; - const IdxT row_ix = subwarp_align::div(IdxT{threadIdx.x} + IdxT{BlockSize} * IdxT{blockIdx.x}); - if (row_ix >= out_codes.extent(0)) { return; } - - const uint32_t pq_dim = raft::div_rounding_up_unsafe(vq_centers.extent(1), pq_centers.extent(1)); - - const uint32_t lane_id = raft::Pow2::mod(threadIdx.x); - const LabelT vq_label = vq_labels(row_ix); - - // write label - auto* out_label_ptr = reinterpret_cast(&out_codes(row_ix, 0)); - if (lane_id == 0) { *out_label_ptr = vq_label; } - - auto* out_codes_ptr = reinterpret_cast(out_label_ptr + 1); - cuvs::neighbors::ivf_pq::detail::bitfield_view_t code_view{out_codes_ptr}; - for (uint32_t j = 0; j < pq_dim; j++) { - // find PQ label - int subspace_offset = j * pq_centers.extent(1) * (1 << PqBits); - auto pq_subspace_view = raft::make_device_matrix_view( - pq_centers.data_handle() + subspace_offset, (uint32_t)(1 << PqBits), pq_centers.extent(1)); - auto pq_centers_smem = - raft::make_device_matrix_view(nullptr, 0, 0); - uint8_t code = cuvs::neighbors::detail::compute_code( - dataset, vq_centers, pq_centers_smem, pq_subspace_view, row_ix, j, vq_label); - // TODO: this writes in global memory one byte per warp, which is very slow. - // It's better to keep the codes in the shared memory or registers and dump them at once. - if (lane_id == 0) { code_view[j] = code; } - } -} - -template -auto process_and_fill_codes_subspaces( - const raft::resources& res, - const vpq_params& params, - const DatasetT& dataset, - raft::device_matrix_view vq_centers, - raft::device_matrix_view pq_centers) - -> raft::device_matrix -{ - using data_t = typename DatasetT::value_type; - using cdataset_t = vpq_dataset; - using label_t = uint32_t; - - const ix_t n_rows = dataset.extent(0); - const ix_t dim = dataset.extent(1); - const ix_t pq_dim = params.pq_dim; - const ix_t pq_bits = params.pq_bits; - const ix_t pq_n_centers = ix_t{1} << pq_bits; - // NB: codes must be aligned at least to sizeof(label_t) to be able to read labels. - const ix_t codes_rowlen = - sizeof(label_t) * (1 + raft::div_rounding_up_safe(pq_dim * pq_bits, 8 * sizeof(label_t))); - - auto codes = raft::make_device_matrix(res, n_rows, codes_rowlen); - - auto stream = raft::resource::get_cuda_stream(res); - - // TODO: with scaling workspace we could choose the batch size dynamically - constexpr ix_t kBlockSize = 256; - const ix_t threads_per_vec = std::min(raft::WarpSize, pq_n_centers); - dim3 threads(kBlockSize, 1, 1); - - auto kernel = [](uint32_t pq_bits) { - switch (pq_bits) { - case 4: - return process_and_fill_codes_subspaces_kernel; - case 8: - return process_and_fill_codes_subspaces_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be 4 or 8", pq_bits); - } - }(pq_bits); - - auto labels = raft::make_device_vector(res, dataset.extent(0)); - cuvs::neighbors::detail::predict_vq(res, dataset, vq_centers, labels.view()); - - dim3 blocks(raft::div_rounding_up_safe(n_rows, kBlockSize / threads_per_vec), 1, 1); - - kernel<<>>( - raft::make_device_matrix_view(codes.data_handle(), n_rows, codes_rowlen), - dataset, - vq_centers, - raft::make_const_mdspan(labels.view()), - pq_centers); - - RAFT_CUDA_TRY(cudaPeekAtLastError()); - - return codes; -} - -template -auto create_pq_codebook(raft::resources const& res, - raft::device_matrix_view residuals, - cuvs::neighbors::vpq_params ps) - -> raft::device_matrix -{ - // Create codebooks (vq initialized to 0s since we don't need here) - auto vq_code_book = - raft::make_device_matrix(res, 1, residuals.extent(1)); - raft::linalg::map_offset(res, vq_code_book.view(), [] __device__(size_t i) { return 0; }); - - auto pq_code_book = cuvs::neighbors::detail::train_pq( - res, ps, residuals, raft::make_const_mdspan(vq_code_book.view())); - - return pq_code_book; -} - /** * @brief Subtract cluster center coordinates from each dataset vector. * @@ -173,64 +52,19 @@ auto compute_residuals(raft::resources const& res, return residuals; } -/**} - * @brief Generate PQ codes for residual vectors using codebook - * - * For each subspace, minimize L2 norm between residual vectors and - * PQ centers to generate codes for residual vectors - * - * @tparam T - * @tparam IdxT - * @tparam LabelT - * @param res raft resources - * @param residuals the residual vectors we're quantizing, size [n_rows, dim] - * @param pq_codebook the codebook of PQ centers size [dim, 1 << pq_bits] - * @oaran ps parameters used with vpq_dataset for pq quantization - * @return device matrix with (packed) codes from vpq, size [n_rows, 1 +ceil((dim / pq_dim * - * pq_bits) /( 8 * sizeof(LabelT)))] - */ -template -auto quantize_residuals(raft::resources const& res, - raft::device_matrix_view residuals, - raft::device_matrix_view pq_codebook, - cuvs::neighbors::vpq_params ps) - -> raft::device_matrix -{ - auto dim = residuals.extent(1); - - // Using a single 0 vector for the vq_codebook, since we already have - // vq centers and computed residuals w.r.t those centers - auto vq_codebook = raft::make_device_matrix(res, 1, dim); - - RAFT_CUDA_TRY(cudaMemsetAsync(vq_codebook.data_handle(), - 0, - vq_codebook.size() * sizeof(T), - raft::resource::get_cuda_stream(res))); - - auto codes = process_and_fill_codes_subspaces( - res, ps, residuals, raft::make_const_mdspan(vq_codebook.view()), pq_codebook); - - return codes; -} - /** * @brief Unpack VPQ codes into 1-byte per code * - * VPQ gives codes in a "packed" form. The first 4 bytes give the code for - * vector quantization, and the remaining bytes the codes for subspace product - * quantization. In the case of 4 bit PQ, each byte stores codes for 2 subspaces - * in a packed form. + * VPQ gives codes in a "packed" form. In the case of 4 bit PQ, each byte stores + * codes for 2 subspaces in a packed form. * - * This function unpacks the codes by discarding the VQ code (which we don't need, - * since we use VPQ only for residual quantization) and (in the case of 4-bit PQ) - * unpackes the subspace codes into one byte each. This is for interoperability - * with open source ScaNN, which doesn't pack codes + * This function unpacks the subspace codes into one byte each. This is for + * interoperability with open source ScaNN, which doesn't pack codes * * @tparam IdxT * @param res raft resources * @param unpacked_codes_view matrix of unpacked codes, size [n_rows, dim / pq_dim] - * @param codes_view packed codes from vpq, size [n_rows, 1 +ceil((dim / pq_dim * pq_bits) /( 8 * - * sizeof(LabelT)))] + * @param codes_view packed codes from vpq, size [n_rows, ceil((dim / pq_dim * pq_bits) / 8)] * @param pq_bits number of bits used for PQ * @param num_subspaces the number of pq_subspaces (dim / pq_dim) */ @@ -246,7 +80,7 @@ void unpack_codes(raft::resources const& res, res, unpacked_codes_view, [codes_view, num_subspaces] __device__(size_t i) { int64_t row_idx = i / num_subspaces; int64_t subspace_idx = i % num_subspaces; - int64_t packed_subspace_idx = 4 + subspace_idx / 2; + int64_t packed_subspace_idx = subspace_idx / 2; uint8_t mask = subspace_idx % 2; uint8_t packed_labels = codes_view(row_idx, packed_subspace_idx); @@ -255,10 +89,6 @@ void unpack_codes(raft::resources const& res, return (mask)*first + (1 - mask) * second; }); - - } else { - raft::matrix::slice_coordinates coords(0, 4, codes_view.extent(0), 4 + num_subspaces); - raft::matrix::slice(res, raft::make_const_mdspan(codes_view), unpacked_codes_view, coords); } } diff --git a/cpp/tests/neighbors/ann_scann.cuh b/cpp/tests/neighbors/ann_scann.cuh index d8bc27fed7..7ab5966687 100644 --- a/cpp/tests/neighbors/ann_scann.cuh +++ b/cpp/tests/neighbors/ann_scann.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -123,6 +123,48 @@ class scann_test : public ::testing::TestWithParam { IdxT expected_bf16_size = ps.index_params.reordering_bf16 ? ps.dim * ps.num_db_vecs : 0; ASSERT_EQ(index.bf16_dataset().size(), expected_bf16_size); + check_code_validity(index, num_subspaces, num_pq_clusters); + } + + void check_code_validity(const index& idx, int num_subspaces, int num_pq_clusters) + { + auto quant_res_host = + raft::make_host_matrix(handle_, ps.num_db_vecs, num_subspaces); + auto quant_soar_host = + raft::make_host_matrix(handle_, ps.num_db_vecs, num_subspaces); + + raft::copy(quant_res_host.data_handle(), + idx.quantized_residuals().data_handle(), + idx.quantized_residuals().size(), + stream_); + raft::copy(quant_soar_host.data_handle(), + idx.quantized_soar_residuals().data_handle(), + idx.quantized_soar_residuals().size(), + stream_); + raft::resource::sync_stream(handle_); + + bool all_zeros = true; + bool has_nan = false; + // Check all codes are in valid range [0, num_pq_clusters] + auto n_vecs_to_check = std::min(ps.num_db_vecs, 50u); + for (IdxT i = 0; i < n_vecs_to_check * num_subspaces; i++) { + if (quant_res_host.data_handle()[i] != 0) { all_zeros = false; } + if (quant_soar_host.data_handle()[i] != 0) { all_zeros = false; } + if (std::isnan(quant_res_host.data_handle()[i])) { + has_nan = true; + break; + } + if (std::isnan(quant_soar_host.data_handle()[i])) { + has_nan = true; + break; + } + ASSERT_LT(quant_res_host.data_handle()[i], num_pq_clusters) + << "AVQ quantized code out of range at position " << i; + ASSERT_LT(quant_soar_host.data_handle()[i], num_pq_clusters) + << "SOAR quantized code out of range at position " << i; + } + ASSERT_FALSE(all_zeros) << "Quantized output contains all zeros"; + ASSERT_FALSE(has_nan) << "Quantized output contains NaN values"; } void SetUp() override // NOLINT From e4a6798d54ab989cd5af8662300047747e7b64e4 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 27 Jan 2026 13:45:01 -0800 Subject: [PATCH 3/8] Clean up for PR Signed-off-by: Mickael Ide --- cpp/src/neighbors/detail/vpq_dataset.cuh | 33 ------------ cpp/src/preprocessing/quantize/vpq_build.cuh | 57 -------------------- 2 files changed, 90 deletions(-) delete mode 100644 cpp/src/preprocessing/quantize/vpq_build.cuh diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index aa9c775966..19ab5ebc72 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -880,37 +880,4 @@ void process_and_fill_codes_subspaces( raft::resource::sync_stream(res); } } -/* -template -auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) - -> vpq_dataset -{ - using label_t = uint32_t; - // Use a heuristic to impute missing parameters. - auto ps = fill_missing_params_heuristics(params, dataset); - - // Train codes - auto vq_code_book = train_vq(res, ps, dataset); - auto pq_code_book = - train_pq(res, ps, dataset, raft::make_const_mdspan(vq_code_book.view())); - - // Encode dataset - const IdxT n_rows = dataset.extent(0); - const IdxT codes_rowlen = sizeof(label_t) * (1 + raft::div_rounding_up_safe( - ps.pq_dim * ps.pq_bits, 8 * sizeof(label_t))); - - auto codes = raft::make_device_matrix(res, n_rows, codes_rowlen); - process_and_fill_codes(res, - ps, - dataset, - raft::make_const_mdspan(pq_code_book.view()), - raft::make_const_mdspan(vq_code_book.view()), - raft::make_device_vector_view(nullptr, 0), - codes.view(), - true); - - return vpq_dataset{ - std::move(vq_code_book), std::move(pq_code_book), std::move(codes)}; -}*/ - } // namespace cuvs::neighbors::detail diff --git a/cpp/src/preprocessing/quantize/vpq_build.cuh b/cpp/src/preprocessing/quantize/vpq_build.cuh deleted file mode 100644 index 7296f44f44..0000000000 --- a/cpp/src/preprocessing/quantize/vpq_build.cuh +++ /dev/null @@ -1,57 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ -#pragma once - -#include -#include - -namespace cuvs::preprocessing::quantize::pq { - -#define CUVS_INST_VPQ_BUILD(T) \ - cuvs::neighbors::vpq_dataset vpq_build( \ - const raft::resources& res, \ - const cuvs::neighbors::vpq_params& params, \ - const raft::host_matrix_view& dataset); \ - cuvs::neighbors::vpq_dataset vpq_build( \ - const raft::resources& res, \ - const cuvs::neighbors::vpq_params& params, \ - const raft::device_matrix_view& dataset); - -CUVS_INST_VPQ_BUILD(float); -CUVS_INST_VPQ_BUILD(half); -CUVS_INST_VPQ_BUILD(int8_t); -CUVS_INST_VPQ_BUILD(uint8_t); - -#undef CUVS_INST_VPQ_BUILD - -/* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ -#pragma once - -#include -#include - -namespace cuvs::preprocessing::quantize::pq { - -#define CUVS_INST_VPQ_BUILD(T) \ - cuvs::neighbors::vpq_dataset vpq_build( \ - const raft::resources& res, \ - const cuvs::neighbors::vpq_params& params, \ - const raft::host_matrix_view& dataset); \ - cuvs::neighbors::vpq_dataset vpq_build( \ - const raft::resources& res, \ - const cuvs::neighbors::vpq_params& params, \ - const raft::device_matrix_view& dataset); - -CUVS_INST_VPQ_BUILD(float); -CUVS_INST_VPQ_BUILD(half); -CUVS_INST_VPQ_BUILD(int8_t); -CUVS_INST_VPQ_BUILD(uint8_t); - -#undef CUVS_INST_VPQ_BUILD - -} // namespace cuvs::preprocessing::quantize::pq From c057566401174cbd04e30119d4036cab2bb32896 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 4 Feb 2026 09:02:59 -0800 Subject: [PATCH 4/8] Improve prefetch logic Signed-off-by: Mickael Ide --- cpp/src/neighbors/detail/vpq_dataset.cuh | 28 +++++++++++-------- .../neighbors/scann/detail/scann_build.cuh | 12 +++----- cpp/src/preprocessing/quantize/detail/pq.cuh | 27 ++++++++++-------- 3 files changed, 37 insertions(+), 30 deletions(-) diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index 19ab5ebc72..133db93b12 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -839,14 +839,18 @@ void process_and_fill_codes_subspaces( } }(pq_bits); - ix_t max_batch_size = std::min(n_rows, kReasonableMaxBatchSize); - auto copy_stream = raft::resource::get_cuda_stream(res); // Using the main stream by default - bool enable_prefetch = false; - if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL)) { - if (raft::resource::get_stream_pool_size(res) >= 1) { - enable_prefetch = true; - copy_stream = raft::resource::get_stream_from_stream_pool(res); - } + ix_t max_batch_size = std::min(n_rows, kReasonableMaxBatchSize); + auto copy_stream = raft::resource::get_cuda_stream(res); // Using the main stream by default + bool enable_prefetch_stream = false; + bool has_cuda_stream_pool_resource = + res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) && + raft::resource::get_stream_pool_size(res) >= 1; + bool need_copy_to_device = + cuvs::spatial::knn::detail::utils::check_pointer_residency(dataset.data_handle()) == + cuvs::spatial::knn::detail::utils::pointer_residency::host_only; + if (has_cuda_stream_pool_resource && need_copy_to_device) { + enable_prefetch_stream = true; + copy_stream = raft::resource::get_stream_from_stream_pool(res); } auto vec_batches = cuvs::spatial::knn::detail::utils::batch_load_iterator( dataset.data_handle(), @@ -855,7 +859,7 @@ void process_and_fill_codes_subspaces( max_batch_size, copy_stream, raft::resource::get_workspace_resource(res), - enable_prefetch); + enable_prefetch_stream); vec_batches.prefetch_next_batch(); for (const auto& batch : vec_batches) { auto batch_view = raft::make_device_matrix_view(batch.data(), ix_t(batch.size()), dim); @@ -876,8 +880,10 @@ void process_and_fill_codes_subspaces( pq_bits, shared_memory_size > 0); RAFT_CUDA_TRY(cudaPeekAtLastError()); - vec_batches.prefetch_next_batch(); - raft::resource::sync_stream(res); + if (enable_prefetch_stream) { + vec_batches.prefetch_next_batch(); + raft::resource::sync_stream(res); + } } } } // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/scann/detail/scann_build.cuh b/cpp/src/neighbors/scann/detail/scann_build.cuh index 1741154583..2a480f4873 100644 --- a/cpp/src/neighbors/scann/detail/scann_build.cuh +++ b/cpp/src/neighbors/scann/detail/scann_build.cuh @@ -216,6 +216,8 @@ index build( cuvs::preprocessing::quantize::pq::transform( res, pq_quantizer, raft::make_const_mdspan(soar_residuals.view()), soar_quant.view()); + // Prefetch next batch + dataset_vec_batches.prefetch_next_batch(); // unpack codes if (pq_quantizer.params_quantizer.pq_bits == 8) { // Copy unpacked codes to host @@ -259,17 +261,11 @@ index build( // quantize dataset to bfloat16, if enabled. Similar to SOAR, quantization // is performed in this loop to improve locality // TODO (rmaschal): Might be more efficient to do on CPU, to avoid DtoH copy - auto bf16_dataset = raft::make_device_matrix(res, batch_view.extent(0), dim); - if (params.reordering_bf16) { + auto bf16_dataset = + raft::make_device_matrix(res, batch_view.extent(0), dim); quantize_bfloat16( res, batch_view, bf16_dataset.view(), params.reordering_noise_shaping_threshold); - } - - // Prefetch next batch - dataset_vec_batches.prefetch_next_batch(); - - if (params.reordering_bf16) { raft::copy(idx.bf16_dataset().data_handle() + batch.offset() * dim, bf16_dataset.data_handle(), bf16_dataset.size(), diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index ee40c36cbe..a3d6dc5774 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -330,23 +330,23 @@ void inverse_transform( } template -auto vpq_convert_math_type(const raft::resources& res, - cuvs::neighbors::vpq_dataset&& src) - -> cuvs::neighbors::vpq_dataset +void vpq_convert_math_type(const raft::resources& res, + const cuvs::neighbors::vpq_dataset& src, + cuvs::neighbors::vpq_dataset& dst) { - auto vq_code_book = raft::make_device_mdarray(res, src.vq_code_book.extents()); - auto pq_code_book = raft::make_device_mdarray(res, src.pq_code_book.extents()); + /*auto vq_code_book = raft::make_device_mdarray(res, src.vq_code_book.extents()); + auto pq_code_book = raft::make_device_mdarray(res, src.pq_code_book.extents());*/ raft::linalg::map(res, - vq_code_book.view(), + dst.vq_code_book.view(), cuvs::spatial::knn::detail::utils::mapping{}, raft::make_const_mdspan(src.vq_code_book.view())); raft::linalg::map(res, - pq_code_book.view(), + dst.pq_code_book.view(), cuvs::spatial::knn::detail::utils::mapping{}, raft::make_const_mdspan(src.pq_code_book.view())); - return cuvs::neighbors::vpq_dataset{ - std::move(vq_code_book), std::move(pq_code_book), std::move(src.data)}; + /*return cuvs::neighbors::vpq_dataset{ + std::move(vq_code_book), std::move(pq_code_book), std::move(src.data)};*/ } template @@ -388,7 +388,12 @@ auto vpq_build_half(const raft::resources& res, const cuvs::neighbors::vpq_params& params, const DatasetT& dataset) -> cuvs::neighbors::vpq_dataset { - return vpq_convert_math_type( - res, vpq_build(res, params, dataset)); + auto old_type = vpq_build(res, params, dataset); + auto new_type = cuvs::neighbors::vpq_dataset{ + raft::make_device_mdarray(res, old_type.vq_code_book.extents()), + raft::make_device_mdarray(res, old_type.pq_code_book.extents()), + std::move(old_type.data)}; + vpq_convert_math_type(res, old_type, new_type); + return new_type; } } // namespace cuvs::preprocessing::quantize::pq::detail From bb93c96a4d29869590ceb6a88f1427530f53005d Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 4 Feb 2026 09:06:38 -0800 Subject: [PATCH 5/8] Remove comment Signed-off-by: Mickael Ide --- cpp/src/preprocessing/quantize/detail/pq.cuh | 5 ----- 1 file changed, 5 deletions(-) diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index a3d6dc5774..e04681b974 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -334,9 +334,6 @@ void vpq_convert_math_type(const raft::resources& res, const cuvs::neighbors::vpq_dataset& src, cuvs::neighbors::vpq_dataset& dst) { - /*auto vq_code_book = raft::make_device_mdarray(res, src.vq_code_book.extents()); - auto pq_code_book = raft::make_device_mdarray(res, src.pq_code_book.extents());*/ - raft::linalg::map(res, dst.vq_code_book.view(), cuvs::spatial::knn::detail::utils::mapping{}, @@ -345,8 +342,6 @@ void vpq_convert_math_type(const raft::resources& res, dst.pq_code_book.view(), cuvs::spatial::knn::detail::utils::mapping{}, raft::make_const_mdspan(src.pq_code_book.view())); - /*return cuvs::neighbors::vpq_dataset{ - std::move(vq_code_book), std::move(pq_code_book), std::move(src.data)};*/ } template From 4102a42658a6b098a1db194fdc8dbabf6a0c2378 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 9 Feb 2026 09:16:59 -0800 Subject: [PATCH 6/8] Add a path to skip batching in PQ Signed-off-by: Mickael Ide --- cpp/src/neighbors/detail/vpq_dataset.cuh | 98 ++++++++++++++++-------- 1 file changed, 66 insertions(+), 32 deletions(-) diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index 133db93b12..fc114dd215 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -467,6 +467,41 @@ void process_and_fill_codes( RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 16]", pq_bits); } }(pq_bits); + bool need_copy_to_device = + cuvs::spatial::knn::detail::utils::check_pointer_residency(dataset.data_handle()) == + cuvs::spatial::knn::detail::utils::pointer_residency::host_only; + bool need_batching = n_rows > kReasonableMaxBatchSize; + auto launch_work = [&](auto& dataset_view, auto& labels_view, auto& codes_view) { + if (inline_vq_labels || (!vq_labels.empty() && !vq_centers.empty())) { + predict_vq(res, dataset_view, vq_centers, labels_view); + } + dim3 blocks( + raft::div_rounding_up_safe(dataset_view.extent(0), kBlockSize / threads_per_vec), 1, 1); + kernel<<>>(codes_view, + dataset_view, + pq_centers, + vq_centers, + raft::make_const_mdspan(labels_view), + rows_in_shared_memory, + pq_bits, + inline_vq_labels); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + }; + auto batch_labels = raft::make_device_vector(res, 0); + if (!need_batching && !need_copy_to_device) { + // No batching needed, launch the kernel directly + auto dataset_view = raft::make_device_matrix_view(dataset.data_handle(), n_rows, dim); + auto labels_view = raft::make_device_vector_view(nullptr, 0); + if (inline_vq_labels) { + batch_labels = raft::make_device_vector(res, dataset_view.extent(0)); + labels_view = batch_labels.view(); + } else if (!vq_labels.empty() && !vq_centers.empty()) { + labels_view = vq_labels; + } + launch_work(dataset_view, labels_view, codes); + return; + } + for (const auto& batch : cuvs::spatial::knn::detail::utils::batch_load_iterator( dataset.data_handle(), n_rows, @@ -475,31 +510,17 @@ void process_and_fill_codes( stream, rmm::mr::get_current_device_resource())) { auto batch_view = raft::make_device_matrix_view(batch.data(), ix_t(batch.size()), dim); - auto batch_labels = raft::make_device_vector(res, 0); auto batch_labels_view = raft::make_device_vector_view(nullptr, 0); if (inline_vq_labels) { batch_labels = raft::make_device_vector(res, batch.size()); batch_labels_view = batch_labels.view(); - predict_vq(res, batch_view, vq_centers, batch_labels_view); - } else { - if (!vq_labels.empty() && !vq_centers.empty()) { - batch_labels_view = raft::make_device_vector_view( - vq_labels.data_handle() + batch.offset(), batch.size()); - predict_vq(res, batch_view, vq_centers, batch_labels_view); - } + } else if (!vq_labels.empty() && !vq_centers.empty()) { + batch_labels_view = raft::make_device_vector_view( + vq_labels.data_handle() + batch.offset(), batch.size()); } - dim3 blocks(raft::div_rounding_up_safe(n_rows, kBlockSize / threads_per_vec), 1, 1); - kernel<<>>( - raft::make_device_matrix_view( - codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen), - batch_view, - pq_centers, - vq_centers, - raft::make_const_mdspan(batch_labels_view), - rows_in_shared_memory, - pq_bits, - inline_vq_labels); - RAFT_CUDA_TRY(cudaPeekAtLastError()); + auto batch_codes_view = raft::make_device_matrix_view( + codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen); + launch_work(batch_view, batch_labels_view, batch_codes_view); } } @@ -848,6 +869,28 @@ void process_and_fill_codes_subspaces( bool need_copy_to_device = cuvs::spatial::knn::detail::utils::check_pointer_residency(dataset.data_handle()) == cuvs::spatial::knn::detail::utils::pointer_residency::host_only; + bool need_batching = n_rows > kReasonableMaxBatchSize; + auto launch_work = [&](auto& dataset_view, auto& labels_view, auto& codes_view) { + if (!vq_labels.empty() && !vq_centers.empty()) { + predict_vq(res, dataset_view, vq_centers, labels_view); + } + dim3 blocks( + raft::div_rounding_up_safe(dataset_view.extent(0), kBlockSize / threads_per_vec), 1, 1); + kernel<<>>(codes_view, + dataset_view, + pq_centers, + vq_centers, + raft::make_const_mdspan(labels_view), + pq_bits, + shared_memory_size > 0); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + }; + if (!need_batching && !need_copy_to_device) { + // No batching and no copy to device needed, launch the kernel directly + auto dataset_view = raft::make_device_matrix_view(dataset.data_handle(), n_rows, dim); + launch_work(dataset_view, vq_labels, codes); + return; + } if (has_cuda_stream_pool_resource && need_copy_to_device) { enable_prefetch_stream = true; copy_stream = raft::resource::get_stream_from_stream_pool(res); @@ -867,19 +910,10 @@ void process_and_fill_codes_subspaces( if (!vq_labels.empty() && !vq_centers.empty()) { batch_labels = raft::make_device_vector_view( vq_labels.data_handle() + batch.offset(), batch.size()); - predict_vq(res, batch_view, vq_centers, batch_labels); } - dim3 blocks(raft::div_rounding_up_safe(batch.size(), kBlockSize / threads_per_vec), 1, 1); - kernel<<>>( - raft::make_device_matrix_view( - codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen), - batch_view, - pq_centers, - vq_centers, - raft::make_const_mdspan(batch_labels), - pq_bits, - shared_memory_size > 0); - RAFT_CUDA_TRY(cudaPeekAtLastError()); + auto batch_codes_view = raft::make_device_matrix_view( + codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen); + launch_work(batch_view, batch_labels, batch_codes_view); if (enable_prefetch_stream) { vec_batches.prefetch_next_batch(); raft::resource::sync_stream(res); From cf3eec1c4190a3075e770ef214d60adf473d4e1f Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 24 Feb 2026 07:25:06 -0800 Subject: [PATCH 7/8] Fix tests Signed-off-by: Mickael Ide --- cpp/tests/neighbors/ann_scann.cuh | 19 ++----------------- .../preprocessing/product_quantization.cu | 13 ++++--------- 2 files changed, 6 insertions(+), 26 deletions(-) diff --git a/cpp/tests/neighbors/ann_scann.cuh b/cpp/tests/neighbors/ann_scann.cuh index 7ab5966687..5d1bcd2068 100644 --- a/cpp/tests/neighbors/ann_scann.cuh +++ b/cpp/tests/neighbors/ann_scann.cuh @@ -143,28 +143,13 @@ class scann_test : public ::testing::TestWithParam { stream_); raft::resource::sync_stream(handle_); - bool all_zeros = true; - bool has_nan = false; - // Check all codes are in valid range [0, num_pq_clusters] + bool all_zeros = true; auto n_vecs_to_check = std::min(ps.num_db_vecs, 50u); - for (IdxT i = 0; i < n_vecs_to_check * num_subspaces; i++) { + for (IdxT i = 0; (i < n_vecs_to_check * num_subspaces) && all_zeros; i++) { if (quant_res_host.data_handle()[i] != 0) { all_zeros = false; } if (quant_soar_host.data_handle()[i] != 0) { all_zeros = false; } - if (std::isnan(quant_res_host.data_handle()[i])) { - has_nan = true; - break; - } - if (std::isnan(quant_soar_host.data_handle()[i])) { - has_nan = true; - break; - } - ASSERT_LT(quant_res_host.data_handle()[i], num_pq_clusters) - << "AVQ quantized code out of range at position " << i; - ASSERT_LT(quant_soar_host.data_handle()[i], num_pq_clusters) - << "SOAR quantized code out of range at position " << i; } ASSERT_FALSE(all_zeros) << "Quantized output contains all zeros"; - ASSERT_FALSE(has_nan) << "Quantized output contains NaN values"; } void SetUp() override // NOLINT diff --git a/cpp/tests/preprocessing/product_quantization.cu b/cpp/tests/preprocessing/product_quantization.cu index d8a83e747a..f1b84854f1 100644 --- a/cpp/tests/preprocessing/product_quantization.cu +++ b/cpp/tests/preprocessing/product_quantization.cu @@ -206,7 +206,7 @@ class ProductQuantizationTest : public ::testing::TestWithParam(n_samples_, n_encoded_cols); @@ -216,18 +216,13 @@ class ProductQuantizationTest : public ::testing::TestWithParam Date: Tue, 24 Feb 2026 10:14:49 -0800 Subject: [PATCH 8/8] Add SCANN reconstruction test Signed-off-by: Mickael Ide --- cpp/tests/neighbors/ann_scann.cuh | 118 +++++++++++++++++++++++++++++- 1 file changed, 117 insertions(+), 1 deletion(-) diff --git a/cpp/tests/neighbors/ann_scann.cuh b/cpp/tests/neighbors/ann_scann.cuh index 5d1bcd2068..eafddec9d2 100644 --- a/cpp/tests/neighbors/ann_scann.cuh +++ b/cpp/tests/neighbors/ann_scann.cuh @@ -8,7 +8,10 @@ #include "ann_utils.cuh" #include #include +#include +#include +#include #include #include #include @@ -124,6 +127,7 @@ class scann_test : public ::testing::TestWithParam { ASSERT_EQ(index.bf16_dataset().size(), expected_bf16_size); check_code_validity(index, num_subspaces, num_pq_clusters); + check_reconstruction(index, num_subspaces); } void check_code_validity(const index& idx, int num_subspaces, int num_pq_clusters) @@ -145,13 +149,125 @@ class scann_test : public ::testing::TestWithParam { bool all_zeros = true; auto n_vecs_to_check = std::min(ps.num_db_vecs, 50u); - for (IdxT i = 0; (i < n_vecs_to_check * num_subspaces) && all_zeros; i++) { + for (IdxT i = 0; i < n_vecs_to_check * num_subspaces; i++) { if (quant_res_host.data_handle()[i] != 0) { all_zeros = false; } if (quant_soar_host.data_handle()[i] != 0) { all_zeros = false; } + // Check that unpacked codes are in valid range + if (ps.index_params.pq_bits == 4) { + ASSERT_LT(quant_res_host.data_handle()[i], num_pq_clusters) + << "AVQ quantized code out of range at position " << i; + ASSERT_LT(quant_soar_host.data_handle()[i], num_pq_clusters) + << "SOAR quantized code out of range at position " << i; + } } ASSERT_FALSE(all_zeros) << "Quantized output contains all zeros"; } + void check_reconstruction(const index& idx, int num_subspaces) + { + cuvs::preprocessing::quantize::pq::params pq_params; + pq_params.pq_bits = ps.index_params.pq_bits; + pq_params.pq_dim = num_subspaces; + pq_params.use_subspaces = true; + pq_params.use_vq = true; // SCANN uses centroids separately + + auto pq_codebook_copy = raft::make_device_matrix( + handle_, idx.pq_codebook().extent(0), idx.pq_codebook().extent(1)); + raft::copy(pq_codebook_copy.data_handle(), + idx.pq_codebook().data_handle(), + idx.pq_codebook().size(), + stream_); + + auto vq_codebook = raft::make_device_matrix( + handle_, idx.centers().extent(0), idx.centers().extent(1)); + raft::copy( + vq_codebook.data_handle(), idx.centers().data_handle(), idx.centers().size(), stream_); + auto empty_data = raft::make_device_matrix(handle_, 0, 0); + + cuvs::preprocessing::quantize::pq::quantizer quantizer{ + pq_params, + cuvs::neighbors::vpq_dataset{ + std::move(vq_codebook), std::move(pq_codebook_copy), std::move(empty_data)}}; + + auto quantized_residuals_device = + raft::make_device_matrix(handle_, ps.num_db_vecs, num_subspaces); + raft::copy(quantized_residuals_device.data_handle(), + idx.quantized_residuals().data_handle(), + idx.quantized_residuals().size(), + stream_); + + // Re-pack 4-bit codes. The 8-bit codes are already in the right format + auto codes_dim = cuvs::preprocessing::quantize::pq::get_quantized_dim(pq_params); + auto packed_codes = raft::make_device_matrix(handle_, ps.num_db_vecs, codes_dim); + + if (ps.index_params.pq_bits == 4) { + raft::linalg::map_offset( + handle_, + packed_codes.view(), + [qr_view = quantized_residuals_device.view(), num_subspaces, codes_dim] __device__( + size_t i) { + int64_t row_idx = i / codes_dim; + int64_t packed_idx = i % codes_dim; + int64_t code_idx = packed_idx * 2; + int64_t code_idx_next = code_idx + 1; + + uint8_t first_code = (code_idx < num_subspaces) ? qr_view(row_idx, code_idx) : 0; + uint8_t second_code = + (code_idx_next < num_subspaces) ? qr_view(row_idx, code_idx_next) : 0; + + return (first_code << 4) | (second_code & 0x0F); + }); + } else { + raft::copy(packed_codes.data_handle(), + quantized_residuals_device.data_handle(), + packed_codes.size(), + stream_); + } + + auto reconstructed_vectors = + raft::make_device_matrix(handle_, ps.num_db_vecs, ps.dim); + auto reconstructed_vectors_view = reconstructed_vectors.view(); + cuvs::preprocessing::quantize::pq::inverse_transform( + handle_, + quantizer, + raft::make_const_mdspan(packed_codes.view()), + reconstructed_vectors_view, + raft::make_const_mdspan(idx.labels())); + + // Compute L2 distances for reconstruction error + auto database_view = + raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); + auto distances = raft::make_device_vector(handle_, ps.num_db_vecs); + raft::linalg::map_offset( + handle_, + distances.view(), + [database_view, reconstructed_vectors_view, dim = ps.dim] __device__(IdxT i) { + float dist = 0.0f; + for (uint32_t j = 0; j < dim; j++) { + float diff = database_view(i, j) - reconstructed_vectors_view(i, j); + dist += diff * diff; + } + return sqrtf(dist / static_cast(dim)); + }); + + float max_allowed_error = 0.95f; + auto distances_host = raft::make_host_vector(handle_, ps.num_db_vecs); + raft::copy(distances_host.data_handle(), distances.data_handle(), ps.num_db_vecs, stream_); + raft::resource::sync_stream(handle_); + + float mean_error = 0.0f; + float max_error = 0.0f; + for (IdxT i = 0; i < ps.num_db_vecs; i++) { + mean_error += distances_host(i); + max_error = std::max(max_error, distances_host(i)); + } + mean_error /= static_cast(ps.num_db_vecs); + ASSERT_LT(mean_error, max_allowed_error) + << "Mean reconstruction error too large: " << mean_error; + ASSERT_LT(max_error, max_allowed_error * 1.5f) + << "Max reconstruction error too large: " << max_error; + } + void SetUp() override // NOLINT { gen_data();