Commit 35cda39

SCaNN Index build (rapidsai#1120)
This PR gives a proof-of-concept implementation of GPU-based index build for the ScaNN index. The ScaNN index defined here is similar in structure to the IVF-PQ index (a tree structure coming from kmeans, plus product quantization of vectors assigned to leaf nodes), together with an “AVQ update” of the kmeans centroids and a spilled cluster assignment from the “SOAR” loss. Other features, optimizations, and customizability options will appear in subsequent PRs.

* scann_build.cuh

  This file contains the implementation for build(..). The general pipeline looks like:

  - Train kmeans centers on sampled data
  - Assign all dataset vectors to kmeans clusters by minimizing L2 loss
  - Update kmeans centers with AVQ
  - Train PQ codebook on sampled residual vectors (here we use VPQ, slightly modified to perform product quantization on individual subspaces, i.e. each subspace has its own codebook)
  - Quantization loop (batched):
    - Compute spilled SOAR labels (performed here to minimize HtoD copies)
    - Compute and quantize residuals/SOAR residuals using the trained PQ codebook
    - If enabled, compute bf16 quantization of dataset vectors (performed here to minimize HtoD copies)

* scann_avq.cuh

  This file contains apply_avq(..), which recomputes cluster centers using AVQ. The main technique is a single application of Theorem 4.2 in https://arxiv.org/pdf/1908.10396 to each cluster, using parameters:

    h_i_parallel   = eta * ||x_i|| ^ (eta - 1)
    h_i_orthogonal = ||x_i|| ^ (eta - 1)

  The implementation of Theorem 4.2 is in compute_avq_centroid(..).

  The overall pipeline for apply_avq(..) is:

  - Build clusters from kmeans cluster assignments
  - For each cluster:
    - Gather cluster vectors into a single matrix
    - Update kmeans center via compute_avq_centroid
  - Rescale updated centroids (I need to add more details about this step).

* scann_quantize.cuh

  This file contains helpers for PQ. Codebooks are created from residual vectors using train_pq from vpq_dataset.cuh (using a single vq center which is set to zero). Unlike in VPQ, codebooks are generated separately for each subspace, rather than collapsing all subspaces into a single space and computing a global codebook.

* scann_soar.cuh

  The main function is compute_soar_labels(..), which computes a second, spilled cluster assignment by minimizing the SOAR loss function (Theorem 3.1 in https://arxiv.org/pdf/2404.00774).

* scann_serialize.cuh

  Contains the implementation of serialize(..). The goal is to serialize ScaNN artifacts in a way that is usable with open-source ScaNN search with minimal additional post-processing. The cluster assignments, quantized vectors (for both the primary and spilled SOAR assignments), and bf16 dataset are all stored in separate .npy files for direct consumption by open-source ScaNN. The codebook and cluster centers are also serialized separately, but require additional post-processing into the correct Protobuf structs (not included in this PR).

Test Plan:

This code is mostly tested via CPU search with open-source ScaNN. Additional protobuf artifacts are created from the cuVS serialized index via an external tool. A Pareto curve for OpenAI 5M is shown here: [figure not included]

Authors:
  - https://github.com/rmaschal

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Artem M. Chirkin (https://github.com/achirkin)

URL: rapidsai#1120
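
The AVQ weights quoted in the commit message can be illustrated on the host. This is a minimal sketch, not the code in scann_avq.cuh: the names `avq_weights` and `compute_avq_weights` are hypothetical, and only the two formulas are taken from the message.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical helper type: the two per-vector AVQ weights from the commit message.
struct avq_weights {
  float parallel;    // h_i_parallel   = eta * ||x_i||^(eta - 1)
  float orthogonal;  // h_i_orthogonal = ||x_i||^(eta - 1)
};

// Compute the AVQ weights for one vector x and a given eta.
// Note: for eta < 1 a zero vector yields infinite weights; a real
// implementation would need to decide how to handle that case.
inline avq_weights compute_avq_weights(const std::vector<float>& x, float eta)
{
  assert(!x.empty());
  float sq_norm = 0.0f;
  for (float v : x) {
    sq_norm += v * v;
  }
  const float norm   = std::sqrt(sq_norm);
  const float h_orth = std::pow(norm, eta - 1.0f);
  return avq_weights{eta * h_orth, h_orth};
}
```

With eta = 1 (the default `partitioning_eta` in the header) both weights equal 1, so the parallel and orthogonal residual components are penalized equally and the weighted loss is the ordinary squared L2 error.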
1 parent 10e0795 commit 35cda39
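
The spilled SOAR assignment described in the commit message can be sketched as a brute-force host loop. This assumes the loss has the form ||r'||^2 + lambda * ||proj_r r'||^2, where r is the residual to the primary center and r' the residual to a candidate center; that form is an assumption about Theorem 3.1 of the cited paper and is not spelled out in the commit. The names `soar_loss` and `soar_label` are hypothetical and are not the cuVS `compute_soar_labels` API.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

using vec = std::vector<float>;

inline float dot(const vec& a, const vec& b)
{
  float s = 0.0f;
  for (std::size_t i = 0; i < a.size(); ++i) {
    s += a[i] * b[i];
  }
  return s;
}

// Assumed SOAR loss for one candidate center:
//   ||r'||^2 + lambda * (<r, r'> / ||r||)^2
inline float soar_loss(const vec& x, const vec& primary, const vec& candidate, float lambda)
{
  vec r(x.size());
  vec rp(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    r[i]  = x[i] - primary[i];
    rp[i] = x[i] - candidate[i];
  }
  const float r_sq  = dot(r, r);
  const float rp_sq = dot(rp, rp);
  if (r_sq == 0.0f) { return rp_sq; }  // no primary residual direction to penalize
  const float proj = dot(r, rp);
  return rp_sq + lambda * proj * proj / r_sq;
}

// Pick the spilled label: the loss minimizer over all centers except the primary one.
inline std::size_t soar_label(const vec& x,
                              const std::vector<vec>& centers,
                              std::size_t primary,
                              float lambda)
{
  assert(centers.size() >= 2 && primary < centers.size());
  std::size_t best = primary;
  float best_loss  = std::numeric_limits<float>::max();
  for (std::size_t c = 0; c < centers.size(); ++c) {
    if (c == primary) { continue; }
    const float loss = soar_loss(x, centers[primary], centers[c], lambda);
    if (loss < best_loss) {
      best_loss = loss;
      best      = c;
    }
  }
  return best;
}
```

With lambda = 0 the spilled label is simply the second-nearest center; increasing lambda favors candidates whose residual is closer to orthogonal to the primary residual.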

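
The commit message notes that each PQ subspace has its own codebook. The sketch below shows what encoding one residual vector under that scheme could look like on the host. The codebook layout is an assumption: a row-major `pq_clusters x dim` matrix in which row `k`, columns `[s * pq_dim, (s + 1) * pq_dim)`, holds codeword `k` of subspace `s`. The function name `pq_encode` is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Encode one residual vector: for every subspace, store the index of the
// nearest codeword (squared L2) from that subspace's own codebook.
inline std::vector<std::uint8_t> pq_encode(const std::vector<float>& residual,
                                           const std::vector<float>& codebook,
                                           std::size_t pq_clusters,
                                           std::size_t pq_dim)
{
  const std::size_t dim = residual.size();
  assert(pq_dim > 0 && dim % pq_dim == 0);
  assert(pq_clusters > 0 && pq_clusters <= 256);
  assert(codebook.size() == pq_clusters * dim);

  const std::size_t n_subspaces = dim / pq_dim;
  std::vector<std::uint8_t> codes(n_subspaces);
  for (std::size_t s = 0; s < n_subspaces; ++s) {
    float best_dist     = std::numeric_limits<float>::max();
    std::uint8_t best_k = 0;
    for (std::size_t k = 0; k < pq_clusters; ++k) {
      float dist = 0.0f;
      for (std::size_t j = 0; j < pq_dim; ++j) {
        const float diff = residual[s * pq_dim + j] - codebook[k * dim + s * pq_dim + j];
        dist += diff * diff;
      }
      if (dist < best_dist) {
        best_dist = dist;
        best_k    = static_cast<std::uint8_t>(k);
      }
    }
    codes[s] = best_k;
  }
  return codes;
}
```

Because the codeword search is restricted to the columns of one subspace, the same code value can refer to unrelated codewords in different subspaces.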
16 files changed (+2446 -1 lines changed)


cpp/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -488,6 +488,8 @@ if(BUILD_SHARED_LIBS)
 src/neighbors/refine/detail/refine_host_half_float.cpp
 src/neighbors/refine/detail/refine_host_int8_t_float.cpp
 src/neighbors/refine/detail/refine_host_uint8_t_float.cpp
+src/neighbors/scann/scann_build_float.cu
+src/neighbors/scann/scann_serialize_float.cu
 src/neighbors/sample_filter.cu
 src/neighbors/tiered_index.cu
 src/neighbors/sparse_brute_force.cu
Lines changed: 311 additions & 0 deletions
@@ -0,0 +1,311 @@
/*
 * Copyright (c) 2024-2025, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/common.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_device_accessor.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/resource/stream_view.hpp>
#include <raft/core/resources.hpp>
#include <raft/util/integer_utils.hpp>
#include <rmm/cuda_stream_view.hpp>

#include <optional>
#include <string>
#include <type_traits>
#include <variant>

namespace cuvs::neighbors::experimental::scann {
/**
 * @defgroup scann_cpp_index_params ScaNN index build parameters
 * @{
 */

/**
 * @brief ANN parameters used by ScaNN to build index
 *
 */
struct index_params : cuvs::neighbors::index_params {
  // partitioning parameters

  /** the number of leaves in the tree **/
  uint32_t n_leaves = 1000;
  /** the number of rows for training the tree structures **/
  int64_t kmeans_n_rows_train = 100000;

  /** the max number of iterations for training the tree structure **/
  uint32_t kmeans_n_iters = 24;

  /** the value of eta for AVQ adjustment during partitioning **/
  float partitioning_eta = 1;

  /** the value of lambda for SOAR spilling **/
  float soar_lambda = 1;

  // Residual quantization params
  /** the dimension of pq subspaces (must divide dataset dimension)**/
  uint32_t pq_dim = 8;

  /** the number of bits for pq codes (must be 4 or 8, for 16 and 256 codes respectively) **/
  uint32_t pq_bits = 8;

  /** the number of rows for PQ training (internally capped to 100k) **/
  int64_t pq_n_rows_train = 100000;

  /** the max number of iterations for PQ training **/
  uint32_t pq_train_iters = 10;

  /** whether to apply bf16 quantization of dataset vectors **/
  bool bf16_enabled = false;

  // TODO - add other scann build params
};

/**
 * @}
 */

static_assert(std::is_aggregate_v<index_params>);

/**
 * @defgroup scann_cpp_index ScaNN index type
 * @{
 */

/**
 * @brief ScaNN index.
 *
 * The index stores the cluster centers, cluster labels, and PQ codebook in device memory,
 * and the quantized residuals (plus the optional bf16 dataset) in host memory.
 *
 * @tparam T data element type
 * @tparam IdxT type of the vector indices (represent dataset.extent(0))
 *
 */
template <typename T, typename IdxT>
struct index : cuvs::neighbors::index {
  static_assert(!raft::is_narrowing_v<uint32_t, IdxT>,
                "IdxT must be able to represent all values of uint32_t");

 public:
  /** Distance metric used for clustering. */
  [[nodiscard]] constexpr inline auto metric() const noexcept -> cuvs::distance::DistanceType
  {
    return metric_;
  }

  /** Total length of the index (number of vectors). */
  IdxT size() const noexcept;

  /** Dimensionality of the data. */
  [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t { return dim_; }

  // Don't allow copying the index for performance reasons (try avoiding copying data)
  index(const index&)                    = delete;
  index(index&&)                         = default;
  auto operator=(const index&) -> index& = delete;
  auto operator=(index&&) -> index&      = default;
  ~index()                               = default;

  /** Construct an empty index. It will need to be trained and populated with `build`*/
  // index(raft::resources const& res) {}

  index(raft::resources const& res,
        cuvs::distance::DistanceType metric,
        uint32_t n_leaves,
        uint32_t pq_bits,
        uint32_t pq_dim,
        IdxT n_rows,
        IdxT dim,
        uint32_t pq_clusters,
        uint32_t pq_num_subspaces,
        bool bf16_enabled)
    : cuvs::neighbors::index(),
      metric_(metric),
      pq_dim_(pq_dim),
      pq_bits_(pq_bits),
      n_leaves_(n_leaves),
      centers_(raft::make_device_matrix<float, IdxT>(res, n_leaves, dim)),
      labels_(raft::make_device_vector<uint32_t, IdxT>(res, n_rows)),
      soar_labels_(raft::make_device_vector<uint32_t, IdxT>(res, n_rows)),
      pq_codebook_(
        raft::make_device_matrix<float, uint32_t, raft::row_major>(res, pq_clusters, dim)),
      quantized_residuals_(
        raft::make_host_matrix<uint8_t, IdxT, raft::row_major>(n_rows, pq_num_subspaces)),
      quantized_soar_residuals_(
        raft::make_host_matrix<uint8_t, IdxT, raft::row_major>(n_rows, pq_num_subspaces)),
      n_rows_(n_rows),
      dim_(dim),
      bf16_dataset_(raft::make_host_matrix<int16_t, IdxT, raft::row_major>(
        bf16_enabled ? n_rows : 0, bf16_enabled ? dim : 0))

  {
  }

  index(raft::resources const& res, const index_params& params, IdxT n_rows, IdxT dim)
    : index(res,
            params.metric,
            params.n_leaves,
            params.pq_bits,
            params.pq_dim,
            n_rows,
            dim,
            1 << params.pq_bits,
            dim / params.pq_dim,
            params.bf16_enabled)
  {
    RAFT_EXPECTS(params.pq_bits == 4 || params.pq_bits == 8, "ScaNN only supports 4 or 8 bit PQ");
    RAFT_EXPECTS(dim >= params.pq_dim,
                 "PQ subspace dimension (pq_dim) must not exceed the dataset dimension");
    RAFT_EXPECTS(dim % params.pq_dim == 0,
                 "PQ subspace dimension (pq_dim) must divide the dataset dimension");
  }

  raft::device_matrix_view<float, IdxT> centers() noexcept { return centers_.view(); }

  raft::device_matrix_view<const float, IdxT> centers() const noexcept
  {
    return raft::make_const_mdspan(centers_.view());
  }

  raft::device_vector_view<uint32_t, IdxT> labels() noexcept { return labels_.view(); }

  raft::device_vector_view<const uint32_t, IdxT> labels() const noexcept
  {
    return raft::make_const_mdspan(labels_.view());
  }

  raft::device_vector_view<uint32_t, IdxT> soar_labels() noexcept { return soar_labels_.view(); }

  raft::device_vector_view<const uint32_t, IdxT> soar_labels() const noexcept
  {
    return raft::make_const_mdspan(soar_labels_.view());
  }

  uint32_t n_rows() const noexcept { return n_rows_; }

  uint32_t n_leaves() const noexcept { return n_leaves_; }

  uint32_t pq_dim() const noexcept { return pq_dim_; }

  raft::device_matrix_view<const float, uint32_t, raft::row_major> pq_codebook() const noexcept
  {
    return raft::make_const_mdspan(pq_codebook_.view());
  }

  raft::device_matrix_view<float, uint32_t, raft::row_major> pq_codebook() noexcept
  {
    return pq_codebook_.view();
  }

  raft::host_matrix_view<const uint8_t, IdxT, raft::row_major> quantized_residuals() const noexcept
  {
    return raft::make_const_mdspan(quantized_residuals_.view());
  }

  raft::host_matrix_view<uint8_t, IdxT, raft::row_major> quantized_residuals() noexcept
  {
    return quantized_residuals_.view();
  }

  raft::host_matrix_view<const uint8_t, IdxT, raft::row_major> quantized_soar_residuals()
    const noexcept
  {
    return raft::make_const_mdspan(quantized_soar_residuals_.view());
  }

  raft::host_matrix_view<uint8_t, IdxT, raft::row_major> quantized_soar_residuals() noexcept
  {
    return quantized_soar_residuals_.view();
  }

  raft::host_matrix_view<int16_t, IdxT, raft::row_major> bf16_dataset() noexcept
  {
    return bf16_dataset_.view();
  }

  raft::host_matrix_view<const int16_t, IdxT, raft::row_major> bf16_dataset() const noexcept
  {
    return raft::make_const_mdspan(bf16_dataset_.view());
  }

 private:
  cuvs::distance::DistanceType metric_;
  IdxT dim_;
  IdxT n_rows_;
  uint32_t pq_dim_;
  uint32_t pq_bits_;
  uint32_t n_leaves_;

  raft::device_matrix<float, IdxT, raft::row_major> centers_;
  raft::device_vector<uint32_t, IdxT> labels_;
  raft::device_vector<uint32_t, IdxT> soar_labels_;
  raft::device_matrix<float, uint32_t, raft::row_major> pq_codebook_;
  raft::host_matrix<uint8_t, IdxT, raft::row_major> quantized_residuals_;
  raft::host_matrix<uint8_t, IdxT, raft::row_major> quantized_soar_residuals_;
  raft::host_matrix<int16_t, IdxT, raft::row_major> bf16_dataset_;
  // TODO - add any data, pointers or structures needed
};
/**
 * @}
 */

/**
 * @defgroup scann_cpp_index_build ScaNN index build functions
 * @{
 */
/**
 * @brief Build the index from the dataset for efficient search.
 *
 */
auto build(raft::resources const& handle,
           const cuvs::neighbors::experimental::scann::index_params& params,
           raft::device_matrix_view<const float, int64_t, raft::row_major> dataset)
  -> cuvs::neighbors::experimental::scann::index<float, int64_t>;

auto build(raft::resources const& handle,
           const cuvs::neighbors::experimental::scann::index_params& params,
           raft::host_matrix_view<const float, int64_t, raft::row_major> dataset)
  -> cuvs::neighbors::experimental::scann::index<float, int64_t>;
/**
 * @}
 */

/**
 * @defgroup scann_cpp_serialize ScaNN serialize functions
 * @{
 */
/**
 * @brief Save the index to files in a directory
 *
 * This serializes the index into a list of files for integration into
 * OSS ScaNN for use with search
 *
 * NOTE: the implementation of ScaNN index build is EXPERIMENTAL and currently
 * not subject to comprehensive, automated testing. Accuracy and performance
 * are not guaranteed, and could diverge without warning.
 *
 */

void serialize(raft::resources const& handle,
               const std::string& file_prefix,
               const cuvs::neighbors::experimental::scann::index<float, int64_t>& index);

/**
 * @}
 */

}  // namespace cuvs::neighbors::experimental::scann
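
The delegating `index` constructor above derives the PQ sizes from `index_params`: `pq_clusters = 1 << pq_bits` and `pq_num_subspaces = dim / pq_dim`. The host-only sketch below reproduces that arithmetic and the resulting array shapes. The names `scann_layout` and `make_layout` are hypothetical and not part of cuVS; the checks mirror the three `RAFT_EXPECTS` conditions, with an added guard against `pq_dim == 0`.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical summary of the array shapes allocated by the index constructor.
struct scann_layout {
  std::uint32_t pq_clusters;       // codewords per subspace: 1 << pq_bits
  std::uint32_t pq_num_subspaces;  // dim / pq_dim
  std::int64_t codebook_rows;      // pq_codebook_ is pq_clusters x dim
  std::int64_t codebook_cols;
  std::int64_t codes_rows;         // quantized residuals are n_rows x pq_num_subspaces
  std::int64_t codes_cols;
};

inline scann_layout make_layout(std::int64_t n_rows,
                                std::int64_t dim,
                                std::uint32_t pq_dim,
                                std::uint32_t pq_bits)
{
  if (pq_bits != 4 && pq_bits != 8) {
    throw std::invalid_argument("ScaNN only supports 4 or 8 bit PQ");
  }
  if (pq_dim == 0 || dim < static_cast<std::int64_t>(pq_dim)) {
    throw std::invalid_argument("pq_dim must be positive and not exceed the dataset dimension");
  }
  if (dim % static_cast<std::int64_t>(pq_dim) != 0) {
    throw std::invalid_argument("pq_dim must divide the dataset dimension");
  }
  scann_layout out{};
  out.pq_clusters      = 1u << pq_bits;
  out.pq_num_subspaces = static_cast<std::uint32_t>(dim / pq_dim);
  out.codebook_rows    = out.pq_clusters;
  out.codebook_cols    = dim;
  out.codes_rows       = n_rows;
  out.codes_cols       = out.pq_num_subspaces;
  return out;
}
```

The sketch validates its inputs before computing anything, whereas the delegating constructor evaluates `1 << params.pq_bits` and `dim / params.pq_dim` before its `RAFT_EXPECTS` checks run. Also note that the header allocates one `uint8_t` per subspace for both 4-bit and 8-bit codes.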

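
The header stores the optional bf16 copy of the dataset as `int16_t` values. The sketch below shows one common way to convert a `float` to a bfloat16 bit pattern (round to nearest even) and back. Whether the build code uses this rounding mode is not stated in the commit, so the rounding choice and the function names `float_to_bf16_bits` and `bf16_bits_to_float` are assumptions made for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Convert a float to the 16 upper bits of its IEEE-754 representation,
// rounding to nearest even (assumed rounding mode).
inline std::int16_t float_to_bf16_bits(float value)
{
  std::uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  std::uint16_t upper;
  if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
    // NaN: keep the sign and exponent, force a quiet-NaN mantissa bit.
    upper = static_cast<std::uint16_t>((bits >> 16) | 0x0040u);
  } else {
    const std::uint32_t lsb = (bits >> 16) & 1u;  // lowest bit that is kept
    bits += 0x7FFFu + lsb;                        // round to nearest, ties to even
    upper = static_cast<std::uint16_t>(bits >> 16);
  }
  std::int16_t out;
  std::memcpy(&out, &upper, sizeof(out));
  return out;
}

// Expand a bfloat16 bit pattern back to float by zero-filling the low 16 bits.
inline float bf16_bits_to_float(std::int16_t stored)
{
  std::uint16_t upper;
  std::memcpy(&upper, &stored, sizeof(upper));
  const std::uint32_t bits = static_cast<std::uint32_t>(upper) << 16;
  float value;
  std::memcpy(&value, &bits, sizeof(value));
  return value;
}
```

Values that need no more than 8 significant bits, such as 1.0f or -2.0f, survive the round trip exactly; other values lose their low-order mantissa bits.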