
Commit 88f7f23

CosineExpanded Distance Metric for CAGRA (rapidsai#197)
Currently only IVF-PQ can be used as the graph building algorithm (NN Descent does not support Cosine). As a result, we are limited by IVF-PQ's restriction that the data be of float / half type for the Cosine metric. This PR also fixes an in-place data modification that was being done by IVF-PQ.

Opportunities for optimization: have NN Descent support Cosine and compute the dataset norms only once, during NN Descent, then re-use them for CAGRA.

[UPDATE 08/21/2025]: NN Descent now supports Cosine. This PR allows the initial CAGRA graph to be built by both methods: IVF_PQ and NN_DESCENT. The IVF_PQ restriction on data types holds, but uint8 and int8 can be supported with NN Descent as the graph building algorithm. ITERATIVE CAGRA SEARCH is currently disabled for Cosine.

[UPDATE 09/23/2025]: This PR also adds Cosine support for IVF_PQ with uint8 / int8 inputs, so the above-mentioned restriction with IVF_PQ has been removed. With this PR, CAGRA supports Cosine wholly, for float, uint8 and int8 inputs. ITERATIVE_SEARCH, however, still has some issues as the graph building method with the Cosine metric and has been disabled.

[UPDATE 09/25/2025]: Binary size comparison for libcuvs.so (CUDA 12.9, x86):

- branch-25.10: 1154.42 MB
- This PR: 1160.73 MB

Total CAGRA testing time:

branch-25.10:
```
      Start 10: NEIGHBORS_ANN_CAGRA_FLOAT_UINT32_TEST
19/37 Test #10: NEIGHBORS_ANN_CAGRA_FLOAT_UINT32_TEST ... Passed 825.43 sec
      Start 11: NEIGHBORS_ANN_CAGRA_HELPERS_TEST
20/37 Test #11: NEIGHBORS_ANN_CAGRA_HELPERS_TEST ........ Passed 0.58 sec
      Start 12: NEIGHBORS_ANN_CAGRA_HALF_UINT32_TEST
21/37 Test #12: NEIGHBORS_ANN_CAGRA_HALF_UINT32_TEST .... Passed 663.97 sec
      Start 13: NEIGHBORS_ANN_CAGRA_INT8_UINT32_TEST
22/37 Test #13: NEIGHBORS_ANN_CAGRA_INT8_UINT32_TEST .... Passed 397.57 sec
      Start 14: NEIGHBORS_ANN_CAGRA_UINT8_UINT32_TEST
23/37 Test #14: NEIGHBORS_ANN_CAGRA_UINT8_UINT32_TEST ... Passed 408.16 sec
```

This PR:
```
      Start 10: NEIGHBORS_ANN_CAGRA_FLOAT_UINT32_TEST
19/37 Test #10: NEIGHBORS_ANN_CAGRA_FLOAT_UINT32_TEST ... Passed 1830.34 sec
      Start 11: NEIGHBORS_ANN_CAGRA_HELPERS_TEST
20/37 Test #11: NEIGHBORS_ANN_CAGRA_HELPERS_TEST ........ Passed 0.45 sec
      Start 12: NEIGHBORS_ANN_CAGRA_HALF_UINT32_TEST
21/37 Test #12: NEIGHBORS_ANN_CAGRA_HALF_UINT32_TEST .... Passed 1444.14 sec
      Start 13: NEIGHBORS_ANN_CAGRA_INT8_UINT32_TEST
22/37 Test #13: NEIGHBORS_ANN_CAGRA_INT8_UINT32_TEST .... Passed 973.64 sec
      Start 14: NEIGHBORS_ANN_CAGRA_UINT8_UINT32_TEST
23/37 Test #14: NEIGHBORS_ANN_CAGRA_UINT8_UINT32_TEST ... Passed 1010.46 sec
```

[UPDATE 09/30/2025]: Updates to CAGRA C++ tests according to the latest PR reviews. New total CAGRA testing time:

branch-25.10:
```
      Start 9: NEIGHBORS_ANN_CAGRA_TEST_BUGS
18/37 Test #9: NEIGHBORS_ANN_CAGRA_TEST_BUGS ........... Passed 16.99 sec
      Start 10: NEIGHBORS_ANN_CAGRA_FLOAT_UINT32_TEST
19/37 Test #10: NEIGHBORS_ANN_CAGRA_FLOAT_UINT32_TEST ... Passed 803.64 sec
      Start 11: NEIGHBORS_ANN_CAGRA_HELPERS_TEST
20/37 Test #11: NEIGHBORS_ANN_CAGRA_HELPERS_TEST ........ Passed 0.49 sec
      Start 12: NEIGHBORS_ANN_CAGRA_HALF_UINT32_TEST
21/37 Test #12: NEIGHBORS_ANN_CAGRA_HALF_UINT32_TEST .... Passed 667.89 sec
      Start 13: NEIGHBORS_ANN_CAGRA_INT8_UINT32_TEST
22/37 Test #13: NEIGHBORS_ANN_CAGRA_INT8_UINT32_TEST .... Passed 420.49 sec
      Start 14: NEIGHBORS_ANN_CAGRA_UINT8_UINT32_TEST
23/37 Test #14: NEIGHBORS_ANN_CAGRA_UINT8_UINT32_TEST ... Passed 429.57 sec
```

This PR:
```
      Start 9: NEIGHBORS_ANN_CAGRA_TEST_BUGS
18/37 Test #9: NEIGHBORS_ANN_CAGRA_TEST_BUGS ........... Passed 26.62 sec
      Start 10: NEIGHBORS_ANN_CAGRA_FLOAT_UINT32_TEST
19/37 Test #10: NEIGHBORS_ANN_CAGRA_FLOAT_UINT32_TEST ... Passed 973.23 sec
      Start 11: NEIGHBORS_ANN_CAGRA_HELPERS_TEST
20/37 Test #11: NEIGHBORS_ANN_CAGRA_HELPERS_TEST ........ Passed 0.43 sec
      Start 12: NEIGHBORS_ANN_CAGRA_HALF_UINT32_TEST
21/37 Test #12: NEIGHBORS_ANN_CAGRA_HALF_UINT32_TEST .... Passed 702.02 sec
      Start 13: NEIGHBORS_ANN_CAGRA_INT8_UINT32_TEST
22/37 Test #13: NEIGHBORS_ANN_CAGRA_INT8_UINT32_TEST .... Passed 491.65 sec
      Start 14: NEIGHBORS_ANN_CAGRA_UINT8_UINT32_TEST
23/37 Test #14: NEIGHBORS_ANN_CAGRA_UINT8_UINT32_TEST ... Passed 541.43 sec
```

Fixes rapidsai#1288
Fixes rapidsai#389

Authors:
- Tarang Jain (https://github.com/tarang-jain)
- Corey J. Nolet (https://github.com/cjnolet)

Approvers:
- Tamas Bela Feher (https://github.com/tfeher)
- Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#197
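
For readers unfamiliar with the metric, the following is a minimal host-side sketch of why precomputed norms help. It assumes CosineExpanded means `1 - <a, b> / (||a|| * ||b||)`; the helper names `l2_norm` and `cosine_distance` are hypothetical and are not part of cuVS.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: Euclidean (L2) norm of one vector.
inline float l2_norm(const std::vector<float>& v)
{
  float acc = 0.0f;
  for (float x : v) { acc += x * x; }
  return std::sqrt(acc);
}

// Hypothetical helper: cosine distance when both norms are already known.
// Only the dot product has to be evaluated per pair; the norms can be
// computed once per dataset row and re-used across many queries.
inline float cosine_distance(const std::vector<float>& a,
                             const std::vector<float>& b,
                             float norm_a,
                             float norm_b)
{
  assert(a.size() == b.size());
  assert(norm_a > 0.0f && norm_b > 0.0f);
  float dot = 0.0f;
  for (std::size_t i = 0; i < a.size(); ++i) { dot += a[i] * b[i]; }
  return 1.0f - dot / (norm_a * norm_b);
}
```

Under this definition, orthogonal vectors are at distance 1, identical directions at 0, and opposite directions at 2.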
1 parent 5496b90 commit 88f7f23

33 files changed

Lines changed: 1033 additions & 169 deletions

cpp/CMakeLists.txt

Lines changed: 12 additions & 0 deletions
@@ -231,6 +231,9 @@ if(NOT BUILD_CPU_ONLY)
     src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_float_uint32_dim128_t8.cu
     src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_float_uint32_dim256_t16.cu
     src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_float_uint32_dim512_t32.cu
+    src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_float_uint32_dim128_t8.cu
+    src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_float_uint32_dim256_t16.cu
+    src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_float_uint32_dim512_t32.cu
     src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_half_uint32_dim128_t8.cu
     src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_half_uint32_dim256_t16.cu
     src/neighbors/detail/cagra/compute_distance_standard_InnerProduct_half_uint32_dim512_t32.cu
@@ -246,6 +249,15 @@ if(NOT BUILD_CPU_ONLY)
     src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint32_dim128_t8.cu
     src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint32_dim256_t16.cu
     src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_half_uint32_dim512_t32.cu
+    src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_half_uint32_dim128_t8.cu
+    src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_half_uint32_dim256_t16.cu
+    src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_half_uint32_dim512_t32.cu
+    src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_int8_uint32_dim128_t8.cu
+    src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_int8_uint32_dim256_t16.cu
+    src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_int8_uint32_dim512_t32.cu
+    src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_uint8_uint32_dim128_t8.cu
+    src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_uint8_uint32_dim256_t16.cu
+    src/neighbors/detail/cagra/compute_distance_standard_CosineExpanded_uint8_uint32_dim512_t32.cu
     src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_int8_uint32_dim128_t8.cu
     src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_int8_uint32_dim256_t16.cu
     src/neighbors/detail/cagra/compute_distance_standard_L2Expanded_int8_uint32_dim512_t32.cu
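
The twelve added source files follow a regular naming pattern: one file per data type and per (dimension, `t`) pair. The sketch below only enumerates those names to make the pattern explicit. The function `cosine_instantiation_files` is hypothetical, and reading the `t` suffix as a thread-team size is an assumption; the diff itself only shows the file names.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper: list the CosineExpanded instantiation files added
// to cpp/CMakeLists.txt in this commit (4 data types x 3 dimension/t pairs).
inline std::vector<std::string> cosine_instantiation_files()
{
  const std::vector<std::string> dtypes{"float", "half", "int8", "uint8"};
  // (dimension, t) pairs exactly as they appear in the file names.
  const std::vector<std::pair<int, int>> dim_team{{128, 8}, {256, 16}, {512, 32}};

  std::vector<std::string> files;
  for (const auto& dtype : dtypes) {
    for (const auto& dt : dim_team) {
      files.push_back("compute_distance_standard_CosineExpanded_" + dtype + "_uint32_dim" +
                      std::to_string(dt.first) + "_t" + std::to_string(dt.second) + ".cu");
    }
  }
  // Postcondition: one file per (dtype, dim/t) combination.
  assert(files.size() == dtypes.size() * dim_team.size());
  return files;
}
```

The count of twelve matches the "12 additions & 0 deletions" reported for this file.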

cpp/include/cuvs/neighbors/cagra.hpp

Lines changed: 65 additions & 2 deletions
@@ -339,6 +339,14 @@ struct index : cuvs::neighbors::index {
     return graph_view_;
   }
 
+  /** Dataset norms for cosine distance [size] */
+  [[nodiscard]] inline auto dataset_norms() const noexcept
+    -> std::optional<raft::device_vector_view<const float, int64_t>>
+  {
+    if (dataset_norms_.has_value()) { return raft::make_const_mdspan(dataset_norms_->view()); }
+    return std::nullopt;
+  }
+
   // Don't allow copying the index for performance reasons (try avoiding copying data)
   /** \cond */
   index(const index&) = delete;
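
The new accessor returns an empty `std::optional` when no norms have been computed, so callers have to check before dereferencing. A minimal host-only sketch of that calling pattern follows. `const_float_view`, `mock_index` and `norm_or` are hypothetical stand-ins written for illustration; they use host memory in place of `raft::device_vector_view` and are not the cuVS types.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-in for raft::device_vector_view<const float, int64_t>.
struct const_float_view {
  const float* data;
  std::size_t size;
};

// Hypothetical stand-in for cagra::index, reduced to the norms member.
class mock_index {
 public:
  // Same shape as the accessor in the diff: a view if norms exist, nullopt otherwise.
  [[nodiscard]] std::optional<const_float_view> dataset_norms() const noexcept
  {
    if (norms_.has_value()) { return const_float_view{norms_->data(), norms_->size()}; }
    return std::nullopt;
  }

  void set_norms(std::vector<float> norms) { norms_ = std::move(norms); }
  void clear_norms() noexcept { norms_.reset(); }

 private:
  std::optional<std::vector<float>> norms_;
};

// Caller-side pattern: check has_value() before touching the view.
inline float norm_or(const mock_index& idx, std::size_t row, float fallback)
{
  const auto view = idx.dataset_norms();
  if (!view.has_value()) { return fallback; }
  assert(row < view->size);
  return view->data[row];
}
```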
@@ -354,7 +362,8 @@ struct index : cuvs::neighbors::index {
     : cuvs::neighbors::index(),
       metric_(metric),
       graph_(raft::make_device_matrix<IdxT, int64_t>(res, 0, 0)),
-      dataset_(new cuvs::neighbors::empty_dataset<int64_t>(0))
+      dataset_(new cuvs::neighbors::empty_dataset<int64_t>(0)),
+      dataset_norms_(std::nullopt)
   {
   }
 
@@ -420,12 +429,21 @@ struct index : cuvs::neighbors::index {
     : cuvs::neighbors::index(),
       metric_(metric),
       graph_(raft::make_device_matrix<IdxT, int64_t>(res, 0, 0)),
-      dataset_(make_aligned_dataset(res, dataset, 16))
+      dataset_(make_aligned_dataset(res, dataset, 16)),
+      dataset_norms_(std::nullopt)
   {
     RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0),
                  "Dataset and knn_graph must have equal number of rows");
     update_graph(res, knn_graph);
 
+    if (metric_ == cuvs::distance::DistanceType::CosineExpanded) {
+      auto p = dynamic_cast<strided_dataset<T, int64_t>*>(dataset_.get());
+      if (p) {
+        auto dataset_view = p->view();
+        if (dataset_view.extent(0) > 0) { compute_dataset_norms_(res); }
+      }
+    }
+
     raft::resource::sync_stream(res);
   }
 
@@ -435,48 +453,81 @@ struct index : cuvs::neighbors::index {
    * If the new dataset rows are aligned on 16 bytes, then only a reference is stored to the
    * dataset. It is the caller's responsibility to ensure that dataset stays alive as long as the
    * index. It is expected that the same set of vectors are used for update_dataset and index build.
+   *
+   * Note: This will clear any precomputed dataset norms.
    */
   void update_dataset(raft::resources const& res,
                       raft::device_matrix_view<const T, int64_t, raft::row_major> dataset)
   {
     dataset_ = make_aligned_dataset(res, dataset, 16);
+    dataset_norms_.reset();
+
+    if (metric() == cuvs::distance::DistanceType::CosineExpanded) {
+      if (dataset.extent(0) > 0) { compute_dataset_norms_(res); }
+    }
   }
 
   /** Set the dataset reference explicitly to a device matrix view with padding. */
   void update_dataset(raft::resources const& res,
                       raft::device_matrix_view<const T, int64_t, raft::layout_stride> dataset)
   {
     dataset_ = make_aligned_dataset(res, dataset, 16);
+    dataset_norms_.reset();
+
+    if (metric() == cuvs::distance::DistanceType::CosineExpanded) {
+      if (dataset.extent(0) > 0) { compute_dataset_norms_(res); }
+    }
   }
 
   /**
    * Replace the dataset with a new dataset.
    *
    * We create a copy of the dataset on the device. The index manages the lifetime of this copy. It
    * is expected that the same set of vectors are used for update_dataset and index build.
+   *
+   * Note: This will clear any precomputed dataset norms.
    */
   void update_dataset(raft::resources const& res,
                       raft::host_matrix_view<const T, int64_t, raft::row_major> dataset)
   {
     dataset_ = make_aligned_dataset(res, dataset, 16);
+    dataset_norms_.reset();
+    if (metric() == cuvs::distance::DistanceType::CosineExpanded) {
+      if (dataset.extent(0) > 0) { compute_dataset_norms_(res); }
+    }
   }
 
   /**
    * Replace the dataset with a new dataset. It is expected that the same set of vectors are used
    * for update_dataset and index build.
+   *
+   * Note: This will clear any precomputed dataset norms.
    */
   template <typename DatasetT>
   auto update_dataset(raft::resources const& res, DatasetT&& dataset)
     -> std::enable_if_t<std::is_base_of_v<cuvs::neighbors::dataset<dataset_index_type>, DatasetT>>
   {
     dataset_ = std::make_unique<DatasetT>(std::move(dataset));
+    dataset_norms_.reset();
+    if (metric() == cuvs::distance::DistanceType::CosineExpanded) {
+      auto p = dynamic_cast<strided_dataset<T, int64_t>*>(dataset_.get());
+      if (p) {
+        auto dataset_view = p->view();
+        if (dataset_view.extent(0) > 0) { compute_dataset_norms_(res); }
+      }
+    }
   }
 
   template <typename DatasetT>
   auto update_dataset(raft::resources const& res, std::unique_ptr<DatasetT>&& dataset)
     -> std::enable_if_t<std::is_base_of_v<neighbors::dataset<dataset_index_type>, DatasetT>>
   {
     dataset_ = std::move(dataset);
+    dataset_norms_.reset();
+    if (metric() == cuvs::distance::DistanceType::CosineExpanded) {
+      auto dataset_view = this->dataset();
+      if (dataset_view.extent(0) > 0) { compute_dataset_norms_(res); }
+    }
   }
 
   /**
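
Every `update_dataset` overload in this hunk follows the same three steps: replace the dataset, drop the cached norms, and recompute them only when the metric is CosineExpanded and the dataset is non-empty. So although the added doc note says the norms are cleared, for cosine indexes they are rebuilt right away. The host-only sketch below illustrates that invalidate-then-recompute pattern. `metric_t` and `norm_caching_index` are hypothetical names, and plain `std::vector` storage stands in for device memory.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical subset of the distance types named in the diff.
enum class metric_t { L2Expanded, InnerProduct, CosineExpanded };

// Hypothetical stand-in for cagra::index that caches per-row norms.
class norm_caching_index {
 public:
  explicit norm_caching_index(metric_t metric) : metric_(metric) {}

  // Replace the data, drop stale norms, then recompute only if they are needed.
  void update_dataset(std::vector<std::vector<float>> rows)
  {
    dataset_ = std::move(rows);
    norms_.reset();
    if (metric_ == metric_t::CosineExpanded && !dataset_.empty()) { compute_norms(); }
  }

  bool has_norms() const noexcept { return norms_.has_value(); }

  float norm(std::size_t row) const
  {
    assert(norms_.has_value() && row < norms_->size());
    return (*norms_)[row];
  }

 private:
  // One L2 norm per dataset row.
  void compute_norms()
  {
    std::vector<float> out;
    out.reserve(dataset_.size());
    for (const auto& row : dataset_) {
      float acc = 0.0f;
      for (float x : row) { acc += x * x; }
      out.push_back(std::sqrt(acc));
    }
    norms_ = std::move(out);
  }

  metric_t metric_;
  std::vector<std::vector<float>> dataset_;
  std::optional<std::vector<float>> norms_;
};
```

Resetting before recomputing means a non-cosine index, or an empty dataset, never exposes norms that belong to a previous dataset.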
@@ -519,6 +570,10 @@ struct index : cuvs::neighbors::index {
   raft::device_matrix<IdxT, int64_t, raft::row_major> graph_;
   raft::device_matrix_view<const IdxT, int64_t, raft::row_major> graph_view_;
   std::unique_ptr<neighbors::dataset<dataset_index_type>> dataset_;
+  // only float distances supported at the moment
+  std::optional<raft::device_vector<float, int64_t>> dataset_norms_;
+
+  void compute_dataset_norms_(raft::resources const& res);
 };
 /**
  * @}
@@ -539,6 +594,7 @@ struct index : cuvs::neighbors::index {
  * The following distance metrics are supported:
  * - L2
  * - InnerProduct (currently only supported with IVF-PQ as the build algorithm)
+ * - CosineExpanded
  *
  * Usage example:
  * @code{.cpp}
@@ -576,6 +632,7 @@ auto build(raft::resources const& res,
  * The following distance metrics are supported:
  * - L2
  * - InnerProduct (currently only supported with IVF-PQ as the build algorithm)
+ * - CosineExpanded
  *
  * Usage example:
  * @code{.cpp}
@@ -613,6 +670,7 @@ auto build(raft::resources const& res,
  * The following distance metrics are supported:
  * - L2
  * - InnerProduct (currently only supported with IVF-PQ as the build algorithm)
+ * - CosineExpanded (dataset norms are computed as float regardless of input data type)
  *
  * Usage example:
  * @code{.cpp}
@@ -649,6 +707,7 @@ auto build(raft::resources const& res,
  *
  * The following distance metrics are supported:
  * - L2
+ * - CosineExpanded (dataset norms are computed as float regardless of input data type)
  *
  * Usage example:
  * @code{.cpp}
@@ -685,6 +744,7 @@ auto build(raft::resources const& res,
  *
  * The following distance metrics are supported:
  * - L2
+ * - CosineExpanded (dataset norms are computed as float regardless of input data type)
  *
  * Usage example:
  * @code{.cpp}
@@ -722,6 +782,7 @@ auto build(raft::resources const& res,
  * The following distance metrics are supported:
  * - L2
  * - InnerProduct (currently only supported with IVF-PQ as the build algorithm)
+ * - CosineExpanded (dataset norms are computed as float regardless of input data type)
  *
  * Usage example:
  * @code{.cpp}
@@ -759,6 +820,7 @@ auto build(raft::resources const& res,
  * The following distance metrics are supported:
  * - L2
  * - InnerProduct (currently only supported with IVF-PQ as the build algorithm)
+ * - CosineExpanded (dataset norms are computed as float regardless of input data type)
  *
  * Usage example:
  * @code{.cpp}
@@ -796,6 +858,7 @@ auto build(raft::resources const& res,
  * The following distance metrics are supported:
  * - L2
  * - InnerProduct (currently only supported with IVF-PQ as the build algorithm)
+ * - CosineExpanded (dataset norms are computed as float regardless of input data type)
  *
  * Usage example:
  * @code{.cpp}

cpp/src/neighbors/cagra.cuh

Lines changed: 34 additions & 1 deletion
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023-2024, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2025, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,10 +22,12 @@
 #include "detail/cagra/cagra_search.cuh"
 #include "detail/cagra/graph_core.cuh"
 
+#include "detail/ann_utils.cuh"
 #include <raft/core/device_mdspan.hpp>
 #include <raft/core/host_device_accessor.hpp>
 #include <raft/core/mdspan.hpp>
 #include <raft/core/resources.hpp>
+#include <raft/linalg/norm.cuh>
 
 #include <cuvs/distance/distance.hpp>
 #include <cuvs/neighbors/cagra.hpp>
@@ -36,6 +38,37 @@
 
 namespace cuvs::neighbors::cagra {
 
+// Member function implementations for cagra::index
+template <typename T, typename IdxT>
+void index<T, IdxT>::compute_dataset_norms_(raft::resources const& res)
+{
+  // Get the dataset view
+  auto dataset_view = this->dataset();
+
+  // Allocate norms vector if not already allocated
+  if (!dataset_norms_.has_value() || dataset_norms_->extent(0) != dataset_view.extent(0)) {
+    dataset_norms_.reset();
+    dataset_norms_ = raft::make_device_vector<float, int64_t>(res, dataset_view.extent(0));
+  }
+
+  constexpr float kScale = cuvs::spatial::knn::detail::utils::config<T>::kDivisor /
+                           cuvs::spatial::knn::detail::utils::config<float>::kDivisor;
+
+  // first scale the dataset and then compute norms
+  auto scaled_sq_op = raft::compose_op(
+    raft::sq_op{}, raft::div_const_op<float>{float(kScale)}, raft::cast_op<float>());
+  raft::linalg::reduce<true, true, T, float, int64_t>(dataset_norms_->data_handle(),
+                                                      dataset_view.data_handle(),
+                                                      dataset_view.stride(0),
+                                                      dataset_view.extent(0),
+                                                      (float)0,
+                                                      raft::resource::get_cuda_stream(res),
+                                                      false,
+                                                      scaled_sq_op,
+                                                      raft::add_op(),
+                                                      raft::sqrt_op{});
+}
+
 /**
  * @defgroup cagra CUDA ANN Graph-based nearest neighbor search
  * @{
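
As a reading aid for `compute_dataset_norms_`, here is a host-side reference of the same reduction: cast each element to float, divide it by a per-type scale, square, sum along the row, and take the square root. The scale values used here (128 for `int8_t`, 256 for `uint8_t`, 1 otherwise) are assumptions made for illustration; the diff only shows that the scale is the ratio `config<T>::kDivisor / config<float>::kDivisor`. The names `divisor` and `scaled_row_norms` are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Assumed per-type scale factors (not taken from the diff).
template <typename T>
struct divisor {
  static constexpr float value = 1.0f;
};
template <>
struct divisor<std::int8_t> {
  static constexpr float value = 128.0f;
};
template <>
struct divisor<std::uint8_t> {
  static constexpr float value = 256.0f;
};

// Hypothetical CPU reference: per-row L2 norm of a row-major matrix,
// computed in float after scaling each element.
template <typename T>
std::vector<float> scaled_row_norms(const std::vector<T>& data,
                                    std::size_t n_rows,
                                    std::size_t n_cols)
{
  assert(data.size() == n_rows * n_cols);
  std::vector<float> norms(n_rows, 0.0f);
  for (std::size_t r = 0; r < n_rows; ++r) {
    float acc = 0.0f;
    for (std::size_t c = 0; c < n_cols; ++c) {
      // cast -> scale -> square, mirroring the composed operator in the diff
      const float v = static_cast<float>(data[r * n_cols + c]) / divisor<T>::value;
      acc += v * v;
    }
    norms[r] = std::sqrt(acc);  // final square root, as in the reduction's last step
  }
  return norms;
}
```

With the assumed divisors, an `int8_t` row `{64, 0}` yields a norm of 0.5, which shows that the result is a float regardless of the input type.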

cpp/src/neighbors/detail/cagra/cagra_build.cuh

Lines changed: 8 additions & 2 deletions
@@ -127,8 +127,9 @@ void build_knn_graph(
   cuvs::neighbors::cagra::graph_build_params::ivf_pq_params pq)
 {
   RAFT_EXPECTS(pq.build_params.metric == cuvs::distance::DistanceType::L2Expanded ||
-                 pq.build_params.metric == cuvs::distance::DistanceType::InnerProduct,
-               "Currently only L2Expanded or InnerProduct metric are supported");
+                 pq.build_params.metric == cuvs::distance::DistanceType::InnerProduct ||
+                 pq.build_params.metric == cuvs::distance::DistanceType::CosineExpanded,
+               "Currently only L2Expanded, InnerProduct and CosineExpanded metrics are supported");
 
   uint32_t node_degree = knn_graph.extent(1);
   raft::common::nvtx::range<cuvs::common::nvtx::domain::cuvs> fun_scope(
@@ -710,6 +711,11 @@ index<T, IdxT> build(
       std::holds_alternative<cagra::graph_build_params::nn_descent_params>(knn_build_params),
     "IVF_PQ for CAGRA graph build does not support BitwiseHamming as a metric. Please "
     "use nn-descent or the iterative CAGRA search build.");
+  RAFT_EXPECTS(
+    params.metric != cuvs::distance::DistanceType::CosineExpanded ||
+      std::holds_alternative<cagra::graph_build_params::ivf_pq_params>(knn_build_params) ||
+      std::holds_alternative<cagra::graph_build_params::nn_descent_params>(knn_build_params),
+    "CosineExpanded distance is not supported for iterative CAGRA graph build.");
 
   // Validate data type for BitwiseHamming metric
   RAFT_EXPECTS(params.metric != cuvs::distance::DistanceType::BitwiseHamming ||
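
The new check in `build` reads as an implication: if the metric is CosineExpanded, then the graph build parameters must hold either the IVF-PQ or the NN Descent alternative. A self-contained sketch of that `std::variant` validation follows. The empty parameter structs, `knn_build_params_t`, `cosine_build_supported` and `validate_cosine_build` are hypothetical stand-ins, and throwing `std::logic_error` is an assumption that replaces `RAFT_EXPECTS`.

```cpp
#include <cassert>
#include <stdexcept>
#include <variant>

// Hypothetical subset of the distance types named in the diff.
enum class metric_t { L2Expanded, InnerProduct, CosineExpanded };

// Hypothetical empty stand-ins for the graph build parameter structs.
struct ivf_pq_params {};
struct nn_descent_params {};
struct iterative_search_params {};

using knn_build_params_t =
  std::variant<ivf_pq_params, nn_descent_params, iterative_search_params>;

// "metric is not cosine OR the builder is one of the supported two",
// the same boolean shape as the condition added in the diff.
inline bool cosine_build_supported(metric_t metric, const knn_build_params_t& params)
{
  return metric != metric_t::CosineExpanded ||
         std::holds_alternative<ivf_pq_params>(params) ||
         std::holds_alternative<nn_descent_params>(params);
}

// Throwing wrapper, playing the role RAFT_EXPECTS plays in the diff.
inline void validate_cosine_build(metric_t metric, const knn_build_params_t& params)
{
  if (!cosine_build_supported(metric, params)) {
    throw std::logic_error(
      "CosineExpanded distance is not supported for iterative CAGRA graph build.");
  }
}
```

Writing the rule as "not cosine, or a supported builder" leaves every other metric unaffected by the check.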
