diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index 637b9276c0..9d3198eb74 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -5,7 +5,7 @@ #pragma once #include "../../../core/nvtx.hpp" -#include "../../vpq_dataset.cuh" +#include "../../../preprocessing/quantize/vpq_build-ext.cuh" #include "graph_core.cuh" #include @@ -2279,8 +2279,7 @@ index build( idx.update_dataset( res, // TODO: hardcoding codebook math to `half`, we can do runtime dispatching later - cuvs::neighbors::vpq_build( - res, *params.compression, dataset)); + cuvs::preprocessing::quantize::pq::vpq_build(res, *params.compression, dataset)); return idx; } diff --git a/cpp/src/neighbors/detail/cagra/factory.cuh b/cpp/src/neighbors/detail/cagra/factory.cuh index a767d16530..75f236a65d 100644 --- a/cpp/src/neighbors/detail/cagra/factory.cuh +++ b/cpp/src/neighbors/detail/cagra/factory.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -11,6 +11,9 @@ #include "search_plan.cuh" #include "search_single_cta.cuh" +#include +#include + #include namespace cuvs::neighbors::cagra::detail { diff --git a/cpp/src/neighbors/detail/vpq_dataset.cuh b/cpp/src/neighbors/detail/vpq_dataset.cuh index e609100a76..fc114dd215 100644 --- a/cpp/src/neighbors/detail/vpq_dataset.cuh +++ b/cpp/src/neighbors/detail/vpq_dataset.cuh @@ -467,6 +467,41 @@ void process_and_fill_codes( RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 16]", pq_bits); } }(pq_bits); + bool need_copy_to_device = + cuvs::spatial::knn::detail::utils::check_pointer_residency(dataset.data_handle()) == + cuvs::spatial::knn::detail::utils::pointer_residency::host_only; + bool need_batching = n_rows > kReasonableMaxBatchSize; + auto launch_work = [&](auto& dataset_view, auto& labels_view, auto& codes_view) { + if (inline_vq_labels || (!vq_labels.empty() && !vq_centers.empty())) { + predict_vq(res, dataset_view, vq_centers, labels_view); + } + dim3 blocks( + raft::div_rounding_up_safe(dataset_view.extent(0), kBlockSize / threads_per_vec), 1, 1); + kernel<<>>(codes_view, + dataset_view, + pq_centers, + vq_centers, + raft::make_const_mdspan(labels_view), + rows_in_shared_memory, + pq_bits, + inline_vq_labels); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + }; + auto batch_labels = raft::make_device_vector(res, 0); + if (!need_batching && !need_copy_to_device) { + // No batching needed, launch the kernel directly + auto dataset_view = raft::make_device_matrix_view(dataset.data_handle(), n_rows, dim); + auto labels_view = raft::make_device_vector_view(nullptr, 0); + if (inline_vq_labels) { + batch_labels = raft::make_device_vector(res, dataset_view.extent(0)); + labels_view = batch_labels.view(); + } else if (!vq_labels.empty() && !vq_centers.empty()) { + labels_view = vq_labels; + } + launch_work(dataset_view, labels_view, codes); + return; + } + for (const auto& batch : cuvs::spatial::knn::detail::utils::batch_load_iterator( dataset.data_handle(), n_rows, @@ -475,53 +510,20 @@ void process_and_fill_codes( stream, rmm::mr::get_current_device_resource())) { auto batch_view = raft::make_device_matrix_view(batch.data(), ix_t(batch.size()), dim); - auto batch_labels = raft::make_device_vector(res, 0); auto batch_labels_view = raft::make_device_vector_view(nullptr, 0); if (inline_vq_labels) { batch_labels = raft::make_device_vector(res, batch.size()); batch_labels_view = batch_labels.view(); - predict_vq(res, batch_view, vq_centers, batch_labels_view); - } else { - if (!vq_labels.empty() && !vq_centers.empty()) { - batch_labels_view = raft::make_device_vector_view( - vq_labels.data_handle() + batch.offset(), batch.size()); - predict_vq(res, batch_view, vq_centers, batch_labels_view); - } + } else if (!vq_labels.empty() && !vq_centers.empty()) { + batch_labels_view = raft::make_device_vector_view( + vq_labels.data_handle() + batch.offset(), batch.size()); } - dim3 blocks(raft::div_rounding_up_safe(n_rows, kBlockSize / threads_per_vec), 1, 1); - kernel<<>>( - raft::make_device_matrix_view( - codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen), - batch_view, - pq_centers, - vq_centers, - raft::make_const_mdspan(batch_labels_view), - rows_in_shared_memory, - pq_bits, - inline_vq_labels); - RAFT_CUDA_TRY(cudaPeekAtLastError()); + auto batch_codes_view = raft::make_device_matrix_view( + codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen); + launch_work(batch_view, batch_labels_view, batch_codes_view); } } -template -auto vpq_convert_math_type(const raft::resources& res, vpq_dataset&& src) - -> vpq_dataset -{ - auto vq_code_book = raft::make_device_mdarray(res, src.vq_code_book.extents()); - auto pq_code_book = raft::make_device_mdarray(res, src.pq_code_book.extents()); - - raft::linalg::map(res, - vq_code_book.view(), - cuvs::spatial::knn::detail::utils::mapping{}, - raft::make_const_mdspan(src.vq_code_book.view())); - raft::linalg::map(res, - pq_code_book.view(), - cuvs::spatial::knn::detail::utils::mapping{}, - raft::make_const_mdspan(src.pq_code_book.view())); - return vpq_dataset{ - std::move(vq_code_book), std::move(pq_code_book), std::move(src.data)}; -} - // Helper for operations using vectorized loads of raft::TxN_t template struct vec_op : raft::TxN_t { @@ -858,14 +860,40 @@ void process_and_fill_codes_subspaces( } }(pq_bits); - ix_t max_batch_size = std::min(n_rows, kReasonableMaxBatchSize); - auto copy_stream = raft::resource::get_cuda_stream(res); // Using the main stream by default - bool enable_prefetch = false; - if (res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL)) { - if (raft::resource::get_stream_pool_size(res) >= 1) { - enable_prefetch = true; - copy_stream = raft::resource::get_stream_from_stream_pool(res); + ix_t max_batch_size = std::min(n_rows, kReasonableMaxBatchSize); + auto copy_stream = raft::resource::get_cuda_stream(res); // Using the main stream by default + bool enable_prefetch_stream = false; + bool has_cuda_stream_pool_resource = + res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) && + raft::resource::get_stream_pool_size(res) >= 1; + bool need_copy_to_device = + cuvs::spatial::knn::detail::utils::check_pointer_residency(dataset.data_handle()) == + cuvs::spatial::knn::detail::utils::pointer_residency::host_only; + bool need_batching = n_rows > kReasonableMaxBatchSize; + auto launch_work = [&](auto& dataset_view, auto& labels_view, auto& codes_view) { + if (!vq_labels.empty() && !vq_centers.empty()) { + predict_vq(res, dataset_view, vq_centers, labels_view); } + dim3 blocks( + raft::div_rounding_up_safe(dataset_view.extent(0), kBlockSize / threads_per_vec), 1, 1); + kernel<<>>(codes_view, + dataset_view, + pq_centers, + vq_centers, + raft::make_const_mdspan(labels_view), + pq_bits, + shared_memory_size > 0); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + }; + if (!need_batching && !need_copy_to_device) { + // No batching and no copy to device needed, launch the kernel directly + auto dataset_view = raft::make_device_matrix_view(dataset.data_handle(), n_rows, dim); + launch_work(dataset_view, vq_labels, codes); + return; + } + if (has_cuda_stream_pool_resource && need_copy_to_device) { + enable_prefetch_stream = true; + copy_stream = raft::resource::get_stream_from_stream_pool(res); } auto vec_batches = cuvs::spatial::knn::detail::utils::batch_load_iterator( dataset.data_handle(), @@ -874,7 +902,7 @@ void process_and_fill_codes_subspaces( max_batch_size, copy_stream, raft::resource::get_workspace_resource(res), - enable_prefetch); + enable_prefetch_stream); vec_batches.prefetch_next_batch(); for (const auto& batch : vec_batches) { auto batch_view = raft::make_device_matrix_view(batch.data(), ix_t(batch.size()), dim); @@ -882,54 +910,14 @@ void process_and_fill_codes_subspaces( if (!vq_labels.empty() && !vq_centers.empty()) { batch_labels = raft::make_device_vector_view( vq_labels.data_handle() + batch.offset(), batch.size()); - predict_vq(res, batch_view, vq_centers, batch_labels); } - dim3 blocks(raft::div_rounding_up_safe(batch.size(), kBlockSize / threads_per_vec), 1, 1); - kernel<<>>( - raft::make_device_matrix_view( - codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen), - batch_view, - pq_centers, - vq_centers, - raft::make_const_mdspan(batch_labels), - pq_bits, - shared_memory_size > 0); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - vec_batches.prefetch_next_batch(); - raft::resource::sync_stream(res); + auto batch_codes_view = raft::make_device_matrix_view( + codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen); + launch_work(batch_view, batch_labels, batch_codes_view); + if (enable_prefetch_stream) { + vec_batches.prefetch_next_batch(); + raft::resource::sync_stream(res); + } } } - -template -auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) - -> vpq_dataset -{ - using label_t = uint32_t; - // Use a heuristic to impute missing parameters. - auto ps = fill_missing_params_heuristics(params, dataset); - - // Train codes - auto vq_code_book = train_vq(res, ps, dataset); - auto pq_code_book = - train_pq(res, ps, dataset, raft::make_const_mdspan(vq_code_book.view())); - - // Encode dataset - const IdxT n_rows = dataset.extent(0); - const IdxT codes_rowlen = sizeof(label_t) * (1 + raft::div_rounding_up_safe( - ps.pq_dim * ps.pq_bits, 8 * sizeof(label_t))); - - auto codes = raft::make_device_matrix(res, n_rows, codes_rowlen); - process_and_fill_codes(res, - ps, - dataset, - raft::make_const_mdspan(pq_code_book.view()), - raft::make_const_mdspan(vq_code_book.view()), - raft::make_device_vector_view(nullptr, 0), - codes.view(), - true); - - return vpq_dataset{ - std::move(vq_code_book), std::move(pq_code_book), std::move(codes)}; -} - } // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/scann/detail/scann_build.cuh b/cpp/src/neighbors/scann/detail/scann_build.cuh index 7805f622d3..37f74f29bb 100644 --- a/cpp/src/neighbors/scann/detail/scann_build.cuh +++ b/cpp/src/neighbors/scann/detail/scann_build.cuh @@ -8,6 +8,7 @@ #include "../../detail/ann_utils.cuh" #include #include +#include #include #include @@ -159,47 +160,17 @@ index build( int dim_per_subspace = params.pq_dim; int num_clusters = 1 << params.pq_bits; - auto full_codebook = - raft::make_device_matrix(res, num_clusters * num_subspaces, dim_per_subspace); - - // Loop each subspace, training codebooks for each - for (int subspace = 0; subspace < num_subspaces; subspace++) { - int sub_dim_start = subspace * dim_per_subspace; - int sub_dim_end = (subspace + 1) * dim_per_subspace; - - auto sub_trainset = raft::make_device_matrix( - res, trainset_residuals.extent(0), (int64_t)dim_per_subspace); - raft::matrix::slice_coordinates avq_sub_coords( - 0, sub_dim_start, trainset_residuals.extent(0), sub_dim_end); - raft::matrix::slice( - res, raft::make_const_mdspan(trainset_residuals.view()), sub_trainset.view(), avq_sub_coords); - - // Set up quantization bits and params - cuvs::neighbors::vpq_params pq_params; - pq_params.pq_bits = params.pq_bits; - // For VPQ, pq_dim is the number of subspaces, not the dimension of the subspaces - pq_params.pq_dim = 1; - // We handle sampling/training set construction above, so use the full set in VPQ - pq_params.pq_kmeans_trainset_fraction = 1.0; - pq_params.kmeans_n_iters = params.pq_train_iters; - - // Create pq codebook for this subspace - auto sub_pq_codebook = - create_pq_codebook(res, raft::make_const_mdspan(sub_trainset.view()), pq_params); - - raft::copy( - res, - raft::make_device_vector_view( - full_codebook.data_handle() + (subspace * sub_pq_codebook.size()), sub_pq_codebook.size()), - raft::make_device_vector_view(sub_pq_codebook.data_handle(), - sub_pq_codebook.size())); - } - raft::resource::sync_stream(res); + cuvs::preprocessing::quantize::pq::params pq_build_params; + pq_build_params.pq_bits = params.pq_bits; + pq_build_params.pq_dim = num_subspaces; + pq_build_params.use_subspaces = true; + pq_build_params.use_vq = false; // We already computed residuals + pq_build_params.kmeans_n_iters = params.pq_train_iters; + pq_build_params.max_train_points_per_pq_code = pq_n_rows_train / num_clusters; + pq_build_params.pq_kmeans_type = cuvs::cluster::kmeans::kmeans_type::KMeansBalanced; - // Set up quantization bits and params - cuvs::neighbors::vpq_params pq_params; - pq_params.pq_bits = params.pq_bits; - pq_params.pq_dim = dataset.extent(1) / params.pq_dim; + auto pq_quantizer = cuvs::preprocessing::quantize::pq::build( + res, pq_build_params, raft::make_const_mdspan(trainset_residuals.view())); dataset_vec_batches.reset(); dataset_vec_batches.prefetch_next_batch(); @@ -230,9 +201,11 @@ index build( batch_soar_labels_view, params.soar_lambda); - // Compute and quantize residuals - auto avq_quant = quantize_residuals( - res, raft::make_const_mdspan(avq_residuals.view()), full_codebook.view(), pq_params); + // Compute and quantize residuals using the public PQ API + int64_t codes_dim = cuvs::preprocessing::quantize::pq::get_quantized_dim(pq_build_params); + auto avq_quant = raft::make_device_matrix(res, batch.size(), codes_dim); + cuvs::preprocessing::quantize::pq::transform( + res, pq_quantizer, raft::make_const_mdspan(avq_residuals.view()), avq_quant.view()); // Compute and quantize SOAR residuals auto soar_residuals = @@ -241,56 +214,64 @@ index build( raft::make_const_mdspan(centroids_view), raft::make_const_mdspan(batch_soar_labels_view)); - auto soar_quant = quantize_residuals( - res, raft::make_const_mdspan(soar_residuals.view()), full_codebook.view(), pq_params); + auto soar_quant = raft::make_device_matrix(res, batch.size(), codes_dim); + cuvs::preprocessing::quantize::pq::transform( + res, pq_quantizer, raft::make_const_mdspan(soar_residuals.view()), soar_quant.view()); + // Prefetch next batch + dataset_vec_batches.prefetch_next_batch(); // unpack codes - auto quantized_residuals = - raft::make_device_matrix(res, batch.size(), num_subspaces); - auto quantized_soar_residuals = - raft::make_device_matrix(res, batch.size(), num_subspaces); - - unpack_codes(res, - quantized_residuals.view(), - raft::make_const_mdspan(avq_quant.view()), - params.pq_bits, - num_subspaces); - unpack_codes(res, - quantized_soar_residuals.view(), - raft::make_const_mdspan(soar_quant.view()), - params.pq_bits, - num_subspaces); + if (pq_quantizer.params_quantizer.pq_bits == 8) { + // Copy unpacked codes to host + // TODO (rmaschal): these copies are blocking and not overlapped + raft::copy(idx.quantized_residuals().data_handle() + batch.offset() * num_subspaces, + avq_quant.data_handle(), + avq_quant.size(), + stream); + + raft::copy(idx.quantized_soar_residuals().data_handle() + batch.offset() * num_subspaces, + soar_quant.data_handle(), + soar_quant.size(), + stream); + } else { + auto quantized_residuals = + raft::make_device_matrix(res, batch.size(), num_subspaces); + auto quantized_soar_residuals = + raft::make_device_matrix(res, batch.size(), num_subspaces); + + unpack_codes(res, + quantized_residuals.view(), + raft::make_const_mdspan(avq_quant.view()), + params.pq_bits, + num_subspaces); + unpack_codes(res, + quantized_soar_residuals.view(), + raft::make_const_mdspan(soar_quant.view()), + params.pq_bits, + num_subspaces); + raft::copy(res, + raft::make_host_vector_view( + idx.quantized_residuals().data_handle() + batch.offset() * num_subspaces, + quantized_residuals.size()), + raft::make_device_vector_view(quantized_residuals.data_handle(), + quantized_residuals.size())); + + raft::copy(res, + raft::make_host_vector_view( + idx.quantized_soar_residuals().data_handle() + batch.offset() * num_subspaces, + quantized_soar_residuals.size()), + raft::make_device_vector_view( + quantized_soar_residuals.data_handle(), quantized_soar_residuals.size())); + } // quantize dataset to bfloat16, if enabled. Similar to SOAR, quantization // is performed in this loop to improve locality // TODO (rmaschal): Might be more efficient to do on CPU, to avoid DtoH copy - auto bf16_dataset = raft::make_device_matrix(res, batch_view.extent(0), dim); - if (params.reordering_bf16) { + auto bf16_dataset = + raft::make_device_matrix(res, batch_view.extent(0), dim); quantize_bfloat16( res, batch_view, bf16_dataset.view(), params.reordering_noise_shaping_threshold); - } - - // Prefetch next batch - dataset_vec_batches.prefetch_next_batch(); - - // Copy unpacked codes to host - // TODO (rmaschal): these copies are blocking and not overlapped - raft::copy(res, - raft::make_host_vector_view( - idx.quantized_residuals().data_handle() + batch.offset() * num_subspaces, - quantized_residuals.size()), - raft::make_device_vector_view(quantized_residuals.data_handle(), - quantized_residuals.size())); - - raft::copy(res, - raft::make_host_vector_view( - idx.quantized_soar_residuals().data_handle() + batch.offset() * num_subspaces, - quantized_soar_residuals.size()), - raft::make_device_vector_view(quantized_soar_residuals.data_handle(), - quantized_soar_residuals.size())); - - if (params.reordering_bf16) { raft::copy(res, raft::make_host_vector_view( idx.bf16_dataset().data_handle() + batch.offset() * dim, bf16_dataset.size()), @@ -305,7 +286,7 @@ index build( // Codebooks from VPQ have the shape [subspace idx, subspace dim, code] // This converts the codebook into matrix format for easy interoperability // with open-source ScaNN search - auto full_codebook_view = full_codebook.view(); + auto full_codebook_view = pq_quantizer.vpq_codebooks.pq_code_book.view(); raft::linalg::map_offset( res, diff --git a/cpp/src/neighbors/scann/detail/scann_quantize.cuh b/cpp/src/neighbors/scann/detail/scann_quantize.cuh index 16ef1f4295..6d5a511189 100644 --- a/cpp/src/neighbors/scann/detail/scann_quantize.cuh +++ b/cpp/src/neighbors/scann/detail/scann_quantize.cuh @@ -3,8 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include "../../detail/vpq_dataset.cuh" -#include "../../ivf_pq/ivf_pq_codepacking.cuh" #include #include #include @@ -21,125 +19,6 @@ namespace cuvs::neighbors::experimental::scann::detail { /** Fix the internal indexing type to avoid integer underflows/overflows */ using ix_t = int64_t; -template -__launch_bounds__(BlockSize) RAFT_KERNEL process_and_fill_codes_subspaces_kernel( - raft::device_matrix_view out_codes, - raft::device_matrix_view dataset, - raft::device_matrix_view vq_centers, - raft::device_vector_view vq_labels, - raft::device_matrix_view pq_centers) -{ - constexpr uint32_t kSubWarpSize = std::min(raft::WarpSize, 1u << PqBits); - using subwarp_align = raft::Pow2; - const IdxT row_ix = subwarp_align::div(IdxT{threadIdx.x} + IdxT{BlockSize} * IdxT{blockIdx.x}); - if (row_ix >= out_codes.extent(0)) { return; } - - const uint32_t pq_dim = raft::div_rounding_up_unsafe(vq_centers.extent(1), pq_centers.extent(1)); - - const uint32_t lane_id = raft::Pow2::mod(threadIdx.x); - const LabelT vq_label = vq_labels(row_ix); - - // write label - auto* out_label_ptr = reinterpret_cast(&out_codes(row_ix, 0)); - if (lane_id == 0) { *out_label_ptr = vq_label; } - - auto* out_codes_ptr = reinterpret_cast(out_label_ptr + 1); - cuvs::neighbors::ivf_pq::detail::bitfield_view_t code_view{out_codes_ptr}; - for (uint32_t j = 0; j < pq_dim; j++) { - // find PQ label - int subspace_offset = j * pq_centers.extent(1) * (1 << PqBits); - auto pq_subspace_view = raft::make_device_matrix_view( - pq_centers.data_handle() + subspace_offset, (uint32_t)(1 << PqBits), pq_centers.extent(1)); - auto pq_centers_smem = - raft::make_device_matrix_view(nullptr, 0, 0); - uint8_t code = cuvs::neighbors::detail::compute_code( - dataset, vq_centers, pq_centers_smem, pq_subspace_view, row_ix, j, vq_label); - // TODO: this writes in global memory one byte per warp, which is very slow. - // It's better to keep the codes in the shared memory or registers and dump them at once. - if (lane_id == 0) { code_view[j] = code; } - } -} - -template -auto process_and_fill_codes_subspaces( - const raft::resources& res, - const vpq_params& params, - const DatasetT& dataset, - raft::device_matrix_view vq_centers, - raft::device_matrix_view pq_centers) - -> raft::device_matrix -{ - using data_t = typename DatasetT::value_type; - using cdataset_t = vpq_dataset; - using label_t = uint32_t; - - const ix_t n_rows = dataset.extent(0); - const ix_t dim = dataset.extent(1); - const ix_t pq_dim = params.pq_dim; - const ix_t pq_bits = params.pq_bits; - const ix_t pq_n_centers = ix_t{1} << pq_bits; - // NB: codes must be aligned at least to sizeof(label_t) to be able to read labels. - const ix_t codes_rowlen = - sizeof(label_t) * (1 + raft::div_rounding_up_safe(pq_dim * pq_bits, 8 * sizeof(label_t))); - - auto codes = raft::make_device_matrix(res, n_rows, codes_rowlen); - - auto stream = raft::resource::get_cuda_stream(res); - - // TODO: with scaling workspace we could choose the batch size dynamically - constexpr ix_t kBlockSize = 256; - const ix_t threads_per_vec = std::min(raft::WarpSize, pq_n_centers); - dim3 threads(kBlockSize, 1, 1); - - auto kernel = [](uint32_t pq_bits) { - switch (pq_bits) { - case 4: - return process_and_fill_codes_subspaces_kernel; - case 8: - return process_and_fill_codes_subspaces_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be 4 or 8", pq_bits); - } - }(pq_bits); - - auto labels = raft::make_device_vector(res, dataset.extent(0)); - cuvs::neighbors::detail::predict_vq(res, dataset, vq_centers, labels.view()); - - dim3 blocks(raft::div_rounding_up_safe(n_rows, kBlockSize / threads_per_vec), 1, 1); - - kernel<<>>( - raft::make_device_matrix_view(codes.data_handle(), n_rows, codes_rowlen), - dataset, - vq_centers, - raft::make_const_mdspan(labels.view()), - pq_centers); - - RAFT_CUDA_TRY(cudaPeekAtLastError()); - - return codes; -} - -template -auto create_pq_codebook(raft::resources const& res, - raft::device_matrix_view residuals, - cuvs::neighbors::vpq_params ps) - -> raft::device_matrix -{ - // Create codebooks (vq initialized to 0s since we don't need here) - auto vq_code_book = - raft::make_device_matrix(res, 1, residuals.extent(1)); - raft::linalg::map_offset(res, vq_code_book.view(), [] __device__(size_t i) { return 0; }); - - auto pq_code_book = cuvs::neighbors::detail::train_pq( - res, ps, residuals, raft::make_const_mdspan(vq_code_book.view())); - - return pq_code_book; -} - /** * @brief Subtract cluster center coordinates from each dataset vector. * @@ -175,61 +54,19 @@ auto compute_residuals(raft::resources const& res, return residuals; } -/**} - * @brief Generate PQ codes for residual vectors using codebook - * - * For each subspace, minimize L2 norm between residual vectors and - * PQ centers to generate codes for residual vectors - * - * @tparam T - * @tparam IdxT - * @tparam LabelT - * @param res raft resources - * @param residuals the residual vectors we're quantizing, size [n_rows, dim] - * @param pq_codebook the codebook of PQ centers size [dim, 1 << pq_bits] - * @oaran ps parameters used with vpq_dataset for pq quantization - * @return device matrix with (packed) codes from vpq, size [n_rows, 1 +ceil((dim / pq_dim * - * pq_bits) /( 8 * sizeof(LabelT)))] - */ -template -auto quantize_residuals(raft::resources const& res, - raft::device_matrix_view residuals, - raft::device_matrix_view pq_codebook, - cuvs::neighbors::vpq_params ps) - -> raft::device_matrix -{ - auto dim = residuals.extent(1); - - // Using a single 0 vector for the vq_codebook, since we already have - // vq centers and computed residuals w.r.t those centers - auto vq_codebook = raft::make_device_matrix(res, 1, dim); - - raft::matrix::fill(res, vq_codebook.view(), T(0)); - - auto codes = process_and_fill_codes_subspaces( - res, ps, residuals, raft::make_const_mdspan(vq_codebook.view()), pq_codebook); - - return codes; -} - /** * @brief Unpack VPQ codes into 1-byte per code * - * VPQ gives codes in a "packed" form. The first 4 bytes give the code for - * vector quantization, and the remaining bytes the codes for subspace product - * quantization. In the case of 4 bit PQ, each byte stores codes for 2 subspaces - * in a packed form. + * VPQ gives codes in a "packed" form. In the case of 4 bit PQ, each byte stores + * codes for 2 subspaces in a packed form. * - * This function unpacks the codes by discarding the VQ code (which we don't need, - * since we use VPQ only for residual quantization) and (in the case of 4-bit PQ) - * unpackes the subspace codes into one byte each. This is for interoperability - * with open source ScaNN, which doesn't pack codes + * This function unpacks the subspace codes into one byte each. This is for + * interoperability with open source ScaNN, which doesn't pack codes * * @tparam IdxT * @param res raft resources * @param unpacked_codes_view matrix of unpacked codes, size [n_rows, dim / pq_dim] - * @param codes_view packed codes from vpq, size [n_rows, 1 +ceil((dim / pq_dim * pq_bits) /( 8 * - * sizeof(LabelT)))] + * @param codes_view packed codes from vpq, size [n_rows, ceil((dim / pq_dim * pq_bits) / 8)] * @param pq_bits number of bits used for PQ * @param num_subspaces the number of pq_subspaces (dim / pq_dim) */ @@ -245,7 +82,7 @@ void unpack_codes(raft::resources const& res, res, unpacked_codes_view, [codes_view, num_subspaces] __device__(size_t i) { int64_t row_idx = i / num_subspaces; int64_t subspace_idx = i % num_subspaces; - int64_t packed_subspace_idx = 4 + subspace_idx / 2; + int64_t packed_subspace_idx = subspace_idx / 2; uint8_t mask = subspace_idx % 2; uint8_t packed_labels = codes_view(row_idx, packed_subspace_idx); @@ -254,10 +91,6 @@ void unpack_codes(raft::resources const& res, return (mask)*first + (1 - mask) * second; }); - - } else { - raft::matrix::slice_coordinates coords(0, 4, codes_view.extent(0), 4 + num_subspaces); - raft::matrix::slice(res, raft::make_const_mdspan(codes_view), unpacked_codes_view, coords); } } diff --git a/cpp/src/neighbors/vpq_dataset.cuh b/cpp/src/neighbors/vpq_dataset.cuh deleted file mode 100644 index 34b011bb0b..0000000000 --- a/cpp/src/neighbors/vpq_dataset.cuh +++ /dev/null @@ -1,41 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -#pragma once - -#include "detail/vpq_dataset.cuh" -#include - -#include - -namespace cuvs::neighbors { - -/** - * @brief Compress a dataset for use in CAGRA-Q search in place of the original data. - * - * @tparam DatasetT a row-major mdspan or mdarray (device or host). - * @tparam MathT a type of the codebook elements and internal math ops. - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] res - * @param[in] params VQ and PQ parameters for compressing the data - * @param[in] dataset a row-major mdspan or mdarray (device or host) [n_rows, dim]. - */ -template -auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) - -> vpq_dataset - -{ - if constexpr (std::is_same_v) { - return detail::vpq_convert_math_type( - res, detail::vpq_build(res, params, dataset)); - } else { - return detail::vpq_build(res, params, dataset); - } -} - -} // namespace cuvs::neighbors diff --git a/cpp/src/preprocessing/quantize/detail/pq.cuh b/cpp/src/preprocessing/quantize/detail/pq.cuh index db7cc8d5c1..e04681b974 100644 --- a/cpp/src/preprocessing/quantize/detail/pq.cuh +++ b/cpp/src/preprocessing/quantize/detail/pq.cuh @@ -328,4 +328,67 @@ void inverse_transform( out, quant.params_quantizer.use_subspaces); } + +template +void vpq_convert_math_type(const raft::resources& res, + const cuvs::neighbors::vpq_dataset& src, + cuvs::neighbors::vpq_dataset& dst) +{ + raft::linalg::map(res, + dst.vq_code_book.view(), + cuvs::spatial::knn::detail::utils::mapping{}, + raft::make_const_mdspan(src.vq_code_book.view())); + raft::linalg::map(res, + dst.pq_code_book.view(), + cuvs::spatial::knn::detail::utils::mapping{}, + raft::make_const_mdspan(src.pq_code_book.view())); +} + +template +auto vpq_build(const raft::resources& res, + const cuvs::neighbors::vpq_params& params, + const DatasetT& dataset) -> cuvs::neighbors::vpq_dataset +{ + using label_t = uint32_t; + // Use a heuristic to impute missing parameters. + auto ps = cuvs::neighbors::detail::fill_missing_params_heuristics(params, dataset); + + // Train codes + auto vq_code_book = cuvs::neighbors::detail::train_vq(res, ps, dataset); + auto pq_code_book = cuvs::neighbors::detail::train_pq( + res, ps, dataset, raft::make_const_mdspan(vq_code_book.view())); + + // Encode dataset + const IdxT n_rows = dataset.extent(0); + const IdxT codes_rowlen = sizeof(label_t) * (1 + raft::div_rounding_up_safe( + ps.pq_dim * ps.pq_bits, 8 * sizeof(label_t))); + + auto codes = raft::make_device_matrix(res, n_rows, codes_rowlen); + cuvs::neighbors::detail::process_and_fill_codes( + res, + ps, + dataset, + raft::make_const_mdspan(pq_code_book.view()), + raft::make_const_mdspan(vq_code_book.view()), + raft::make_device_vector_view(nullptr, 0), + codes.view(), + true); + + return cuvs::neighbors::vpq_dataset{ + std::move(vq_code_book), std::move(pq_code_book), std::move(codes)}; +} + +template +auto vpq_build_half(const raft::resources& res, + const cuvs::neighbors::vpq_params& params, + const DatasetT& dataset) -> cuvs::neighbors::vpq_dataset +{ + auto old_type = vpq_build(res, params, dataset); + auto new_type = cuvs::neighbors::vpq_dataset{ + raft::make_device_mdarray(res, old_type.vq_code_book.extents()), + raft::make_device_mdarray(res, old_type.pq_code_book.extents()), + std::move(old_type.data)}; + vpq_convert_math_type(res, old_type, new_type); + return new_type; +} } // namespace cuvs::preprocessing::quantize::pq::detail diff --git a/cpp/src/preprocessing/quantize/pq.cu b/cpp/src/preprocessing/quantize/pq.cu index 4a381c59ca..761474bdf8 100644 --- a/cpp/src/preprocessing/quantize/pq.cu +++ b/cpp/src/preprocessing/quantize/pq.cu @@ -52,4 +52,25 @@ CUVS_INST_QUANTIZATION(float, uint8_t); #undef CUVS_INST_QUANTIZATION +#define CUVS_INST_VPQ_BUILD(T) \ + auto vpq_build(const raft::resources& res, \ + const cuvs::neighbors::vpq_params& params, \ + const raft::host_matrix_view& dataset) \ + { \ + return detail::vpq_build_half(res, params, dataset); \ + } \ + auto vpq_build(const raft::resources& res, \ + const cuvs::neighbors::vpq_params& params, \ + const raft::device_matrix_view& dataset) \ + { \ + return detail::vpq_build_half(res, params, dataset); \ + } + +CUVS_INST_VPQ_BUILD(float); +CUVS_INST_VPQ_BUILD(half); +CUVS_INST_VPQ_BUILD(int8_t); +CUVS_INST_VPQ_BUILD(uint8_t); + +#undef CUVS_INST_VPQ_BUILD + } // namespace cuvs::preprocessing::quantize::pq diff --git a/cpp/src/preprocessing/quantize/vpq_build-ext.cuh b/cpp/src/preprocessing/quantize/vpq_build-ext.cuh new file mode 100644 index 0000000000..1745e53a33 --- /dev/null +++ b/cpp/src/preprocessing/quantize/vpq_build-ext.cuh @@ -0,0 +1,28 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include +#include + +namespace cuvs::preprocessing::quantize::pq { + +#define CUVS_INST_VPQ_BUILD(T) \ + cuvs::neighbors::vpq_dataset vpq_build( \ + const raft::resources& res, \ + const cuvs::neighbors::vpq_params& params, \ + const raft::host_matrix_view& dataset); \ + cuvs::neighbors::vpq_dataset vpq_build( \ + const raft::resources& res, \ + const cuvs::neighbors::vpq_params& params, \ + const raft::device_matrix_view& dataset); + +CUVS_INST_VPQ_BUILD(float); +CUVS_INST_VPQ_BUILD(half); +CUVS_INST_VPQ_BUILD(int8_t); +CUVS_INST_VPQ_BUILD(uint8_t); + +#undef CUVS_INST_VPQ_BUILD +} // namespace cuvs::preprocessing::quantize::pq diff --git a/cpp/tests/neighbors/ann_scann.cuh b/cpp/tests/neighbors/ann_scann.cuh index d8bc27fed7..eafddec9d2 100644 --- a/cpp/tests/neighbors/ann_scann.cuh +++ b/cpp/tests/neighbors/ann_scann.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -8,7 +8,10 @@ #include "ann_utils.cuh" #include #include +#include +#include +#include #include #include #include @@ -123,6 +126,146 @@ class scann_test : public ::testing::TestWithParam { IdxT expected_bf16_size = ps.index_params.reordering_bf16 ? ps.dim * ps.num_db_vecs : 0; ASSERT_EQ(index.bf16_dataset().size(), expected_bf16_size); + check_code_validity(index, num_subspaces, num_pq_clusters); + check_reconstruction(index, num_subspaces); + } + + void check_code_validity(const index& idx, int num_subspaces, int num_pq_clusters) + { + auto quant_res_host = + raft::make_host_matrix(handle_, ps.num_db_vecs, num_subspaces); + auto quant_soar_host = + raft::make_host_matrix(handle_, ps.num_db_vecs, num_subspaces); + + raft::copy(quant_res_host.data_handle(), + idx.quantized_residuals().data_handle(), + idx.quantized_residuals().size(), + stream_); + raft::copy(quant_soar_host.data_handle(), + idx.quantized_soar_residuals().data_handle(), + idx.quantized_soar_residuals().size(), + stream_); + raft::resource::sync_stream(handle_); + + bool all_zeros = true; + auto n_vecs_to_check = std::min(ps.num_db_vecs, 50u); + for (IdxT i = 0; i < n_vecs_to_check * num_subspaces; i++) { + if (quant_res_host.data_handle()[i] != 0) { all_zeros = false; } + if (quant_soar_host.data_handle()[i] != 0) { all_zeros = false; } + // Check that unpacked codes are in valid range + if (ps.index_params.pq_bits == 4) { + ASSERT_LT(quant_res_host.data_handle()[i], num_pq_clusters) + << "AVQ quantized code out of range at position " << i; + ASSERT_LT(quant_soar_host.data_handle()[i], num_pq_clusters) + << "SOAR quantized code out of range at position " << i; + } + } + ASSERT_FALSE(all_zeros) << "Quantized output contains all zeros"; + } + + void check_reconstruction(const index& idx, int num_subspaces) + { + cuvs::preprocessing::quantize::pq::params pq_params; + pq_params.pq_bits = ps.index_params.pq_bits; + pq_params.pq_dim = num_subspaces; + pq_params.use_subspaces = true; + pq_params.use_vq = true; // SCANN uses centroids separately + + auto pq_codebook_copy = raft::make_device_matrix( + handle_, idx.pq_codebook().extent(0), idx.pq_codebook().extent(1)); + raft::copy(pq_codebook_copy.data_handle(), + idx.pq_codebook().data_handle(), + idx.pq_codebook().size(), + stream_); + + auto vq_codebook = raft::make_device_matrix( + handle_, idx.centers().extent(0), idx.centers().extent(1)); + raft::copy( + vq_codebook.data_handle(), idx.centers().data_handle(), idx.centers().size(), stream_); + auto empty_data = raft::make_device_matrix(handle_, 0, 0); + + cuvs::preprocessing::quantize::pq::quantizer quantizer{ + pq_params, + cuvs::neighbors::vpq_dataset{ + std::move(vq_codebook), std::move(pq_codebook_copy), std::move(empty_data)}}; + + auto quantized_residuals_device = + raft::make_device_matrix(handle_, ps.num_db_vecs, num_subspaces); + raft::copy(quantized_residuals_device.data_handle(), + idx.quantized_residuals().data_handle(), + idx.quantized_residuals().size(), + stream_); + + // Re-pack 4-bit codes. The 8-bit codes are already in the right format + auto codes_dim = cuvs::preprocessing::quantize::pq::get_quantized_dim(pq_params); + auto packed_codes = raft::make_device_matrix(handle_, ps.num_db_vecs, codes_dim); + + if (ps.index_params.pq_bits == 4) { + raft::linalg::map_offset( + handle_, + packed_codes.view(), + [qr_view = quantized_residuals_device.view(), num_subspaces, codes_dim] __device__( + size_t i) { + int64_t row_idx = i / codes_dim; + int64_t packed_idx = i % codes_dim; + int64_t code_idx = packed_idx * 2; + int64_t code_idx_next = code_idx + 1; + + uint8_t first_code = (code_idx < num_subspaces) ? qr_view(row_idx, code_idx) : 0; + uint8_t second_code = + (code_idx_next < num_subspaces) ? qr_view(row_idx, code_idx_next) : 0; + + return (first_code << 4) | (second_code & 0x0F); + }); + } else { + raft::copy(packed_codes.data_handle(), + quantized_residuals_device.data_handle(), + packed_codes.size(), + stream_); + } + + auto reconstructed_vectors = + raft::make_device_matrix(handle_, ps.num_db_vecs, ps.dim); + auto reconstructed_vectors_view = reconstructed_vectors.view(); + cuvs::preprocessing::quantize::pq::inverse_transform( + handle_, + quantizer, + raft::make_const_mdspan(packed_codes.view()), + reconstructed_vectors_view, + raft::make_const_mdspan(idx.labels())); + + // Compute L2 distances for reconstruction error + auto database_view = + raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); + auto distances = raft::make_device_vector(handle_, ps.num_db_vecs); + raft::linalg::map_offset( + handle_, + distances.view(), + [database_view, reconstructed_vectors_view, dim = ps.dim] __device__(IdxT i) { + float dist = 0.0f; + for (uint32_t j = 0; j < dim; j++) { + float diff = database_view(i, j) - reconstructed_vectors_view(i, j); + dist += diff * diff; + } + return sqrtf(dist / static_cast(dim)); + }); + + float max_allowed_error = 0.95f; + auto distances_host = raft::make_host_vector(handle_, ps.num_db_vecs); + raft::copy(distances_host.data_handle(), distances.data_handle(), ps.num_db_vecs, stream_); + raft::resource::sync_stream(handle_); + + float mean_error = 0.0f; + float max_error = 0.0f; + for (IdxT i = 0; i < ps.num_db_vecs; i++) { + mean_error += distances_host(i); + max_error = std::max(max_error, distances_host(i)); + } + mean_error /= static_cast(ps.num_db_vecs); + ASSERT_LT(mean_error, max_allowed_error) + << "Mean reconstruction error too large: " << mean_error; + ASSERT_LT(max_error, max_allowed_error * 1.5f) + << "Max reconstruction error too large: " << max_error; } void SetUp() override // NOLINT diff --git a/cpp/tests/preprocessing/product_quantization.cu b/cpp/tests/preprocessing/product_quantization.cu index d8a83e747a..f1b84854f1 100644 --- a/cpp/tests/preprocessing/product_quantization.cu +++ b/cpp/tests/preprocessing/product_quantization.cu @@ -206,7 +206,7 @@ class ProductQuantizationTest : public ::testing::TestWithParam(n_samples_, n_encoded_cols); @@ -216,18 +216,13 @@ class ProductQuantizationTest : public ::testing::TestWithParam