From 05da99744c74bf9fe3d71d53e1d12141e3a1a456 Mon Sep 17 00:00:00 2001 From: achirkin Date: Thu, 2 Jul 2026 08:45:08 +0200 Subject: [PATCH] Revert cagra-hnswlib wrapper refactoring to allow overriding heuristics params --- cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib.cu | 36 +++++++++++++++++-- .../ann/src/cuvs/cuvs_cagra_hnswlib_wrapper.h | 35 ++++++++++++++++-- 2 files changed, 65 insertions(+), 6 deletions(-) diff --git a/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib.cu b/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib.cu index c903b39fcc..a784b86aa1 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib.cu +++ b/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 */ @@ -40,8 +40,38 @@ auto parse_build_param(const nlohmann::json& conf) -> // Reuse the CAGRA wrapper params parser ::parse_build_param(conf, cagra_params); - - if (conf.contains("M")) { hnsw_params.M = conf.at("M"); } + // If the users provides parameter M, we can use the CAGRA-HNSW heuristics to find optimal + // parameters for the dataset and HNSW reference. + if (conf.contains("M")) { + // Postpone the parsing of the CAGRA build params until the dataset extents are known. + // We the default parameters depend on the dataset extents; and we still would like to be able + // to override them. + cagra_params.cagra_params = [conf, hnsw_params](raft::matrix_extent extents, + cuvs::distance::DistanceType dist_type) { + auto ps = cuvs::neighbors::cagra::index_params::from_hnsw_params( + extents, + conf.at("M"), + hnsw_params.ef_construction, + cuvs::neighbors::cagra::hnsw_heuristic_type::SAME_GRAPH_FOOTPRINT, + dist_type); + ps.metric = dist_type; + // Parse ACE parameters if provided + if (conf.contains("npartitions") || conf.contains("build_dir") || + conf.contains("ef_construction") || conf.contains("use_disk")) { + auto ace_params = cuvs::neighbors::cagra::graph_build_params::ace_params(); + if (conf.contains("npartitions")) { ace_params.npartitions = conf.at("npartitions"); } + if (conf.contains("build_dir")) { ace_params.build_dir = conf.at("build_dir"); } + if (conf.contains("ef_construction")) { + ace_params.ef_construction = conf.at("ef_construction"); + } + if (conf.contains("use_disk")) { ace_params.use_disk = conf.at("use_disk"); } + ps.graph_build_params = ace_params; + } + // NB: above, we only provide the defaults. Below we parse the explicit parameters as usual. + ::parse_build_param(conf, ps); + return ps; + }; + } return param; } diff --git a/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib_wrapper.h b/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib_wrapper.h index db618f6559..7583fdd3f0 100644 --- a/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib_wrapper.h +++ b/cpp/bench/ann/src/cuvs/cuvs_cagra_hnswlib_wrapper.h @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -85,9 +85,38 @@ void cuvs_cagra_hnswlib::build(const T* dataset, size_t nrow) // when the data set is on host, we can pass it directly to HNSW bool dataset_is_on_host = raft::get_device_for_address(dataset) == -1; - auto dataset_view = raft::make_host_matrix_view(dataset, nrow, this->dim_); + // re-use the CAGRA wrapper to parse build params + auto bps = build_param_.cagra_build_params; + // Not very conveniently, the CAGRA wrapper resolves parameters after the dataset shape is known, + // so it takes a lambda to do it. Even though we know the shape, we want to use the wrapper as-is, + // so we just modify that lambda. + bps.cagra_params = [dataset_is_on_host, orig_cagra_params = bps.cagra_params]( + auto dataset_extents, auto metric) { + auto params = orig_cagra_params(dataset_extents, metric); + params.attach_dataset_on_build = !dataset_is_on_host; + return params; + }; + cuvs_cagra cagra_wrapper{this->metric_, this->dim_, bps}; + + // build the CAGRA index + cagra_wrapper.build(dataset, nrow); + auto& cagra_index = *cagra_wrapper.get_index(); + + // pass the dataset directly to HNSW if it's on the host + std::optional> opt_dataset_view = std::nullopt; + if (dataset_is_on_host) { + opt_dataset_view.emplace( + raft::make_host_matrix_view(dataset, nrow, this->dim_)); + } + // convert the index to HNSW format - hnsw_index_ = cuvs::neighbors::hnsw::build(handle_, build_param_.hnsw_index_params, dataset_view); + hnsw_index_ = cuvs::neighbors::hnsw::from_cagra( + handle_, build_param_.hnsw_index_params, cagra_index, opt_dataset_view); + + // special treatment in save/serialize step + if (cagra_index.dataset_fd().has_value() && cagra_index.graph_fd().has_value()) { + cagra_ace_build_ = true; + } } template