diff --git a/CMakeLists.txt b/CMakeLists.txt index b1661249d..70945bf0d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,9 @@ -cmake_minimum_required(VERSION 3.10) +option(USE_CUDA "Build Cuda code" OFF) +if(USE_CUDA) + cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR) +else() + cmake_minimum_required(VERSION 3.10) +endif() cmake_policy(SET CMP0077 NEW) set(CMAKE_CXX_STANDARD 20) @@ -24,10 +29,30 @@ include(cmake/san.cmake) # ---------------------------------------------------------------------------------------------- project(VectorSimilarity) +if (USE_CUDA) + # List of architectures to generate device code + set(CMAKE_CUDA_ARCHITECTURES "native") + # Enable CUDA compilation for this project + enable_language(CUDA) + # Add definition for conditional compilation of CUDA components + add_definitions(-DUSE_CUDA) + # Perform all RAFT-specific CMake setup + include(cmake/raft.cmake) + # Required flags for compiling RAFT + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr") + set(CMAKE_CUDA_FLAGS_RELEASE "-O3") + set(CMAKE_CUDA_FLAGS_DEBUG "-g") + set(CMAKE_CUDA_STANDARD 17) + if(${CUDAToolkit_VERSION_MAJOR} GREATER 10) + # cuda11 support --threads for compile some large .cu more efficient + add_compile_options($<$:--threads=4>) + endif() +endif() + # Only do these if this is the main project, and not if it is included through add_subdirectory set_property(GLOBAL PROPERTY USE_FOLDERS ON) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions -fPIC ${CLANG_SAN_FLAGS} ${LLVM_CXX_FLAGS} ${COV_CXX_FLAGS}") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions -fPIC -pthread ${CLANG_SAN_FLAGS} ${LLVM_CXX_FLAGS} ${COV_CXX_FLAGS} -lrt") set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_LINKER_FLAGS} ${LLVM_LD_FLAGS}") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_LINKER_FLAGS} ${LLVM_LD_FLAGS}") diff --git a/check-format.sh b/check-format.sh index b474077c4..762fd519e 100755 --- a/check-format.sh +++ b/check-format.sh @@ -1,6 +1,6 @@ #!/bin/bash -CLANG_FMT_SRCS=$(find ./src/ \( -name '*.c' -o -name '*.cc' -o -name '*.cpp' -o -name '*.h' -o -name '*.hh' -o -name '*.hpp' \)) -CLANG_FMT_TESTS="$(find ./tests/ -type d \( -path ./tests/unit/build \) -prune -false -o \( -name '*.c' -o -name '*.cc' -o -name '*.cpp' -o -name '*.h' -o -name '*.hh' -o -name '*.hpp' \))" +CLANG_FMT_SRCS=$(find ./src/ \( -name '*.c' -o -name '*.cc' -o -name '*.cpp' -o -name '*.h' -o -name '*.hh' -o -name '*.hpp' -o -name '*.cuh' -o -name '*.cu' \)) +CLANG_FMT_TESTS="$(find ./tests/ -type d \( -path ./tests/unit/build \) -prune -false -o \( -name '*.c' -o -name '*.cc' -o -name '*.cpp' -o -name '*.h' -o -name '*.hh' -o -name '*.hpp' -o -name '*.cuh' -o -name '*.cu' \))" E=0 for filename in $CLANG_FMT_SRCS $CLANG_FMT_TESTS; do diff --git a/cmake/raft.cmake b/cmake/raft.cmake new file mode 100644 index 000000000..b9d1c15cd --- /dev/null +++ b/cmake/raft.cmake @@ -0,0 +1,81 @@ +if(USE_CUDA) + # Set which version of RAPIDS to use + set(RAPIDS_VERSION 23.12) + # Set which version of RAFT to use (defined separately for testing + # minimal dependency changes if necessary) + set(RAFT_VERSION "${RAPIDS_VERSION}") + set(RAFT_FORK "rapidsai") + set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") + + # Download CMake file for bootstrapping RAPIDS-CMake, a utility that + # simplifies handling of complex RAPIDS dependencies + if (NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) + file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/RAPIDS.cmake + ${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) + endif() + include(${CMAKE_CURRENT_BINARY_DIR}/RAPIDS.cmake) + + # General tool for orchestrating RAPIDS dependencies + include(rapids-cmake) + # CPM helper functions with dependency tracking + include(rapids-cpm) + rapids_cpm_init() + # Common CMake CUDA logic + include(rapids-cuda) + # Include required dependencies in Project-Config.cmake modules + # include(rapids-export) TODO(wphicks) + # Functions to find system dependencies with dependency tracking + include(rapids-find) + + # Correctly handle supported CUDA architectures + # (From rapids-cuda) + rapids_cuda_init_architectures(VectorSimilarity) + + # Find system CUDA toolkit + rapids_find_package(CUDAToolkit REQUIRED) + + set(RAFT_VERSION "${RAPIDS_VERSION}") + set(RAFT_FORK "rapidsai") + set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") + + function(find_and_configure_raft) + set(oneValueArgs VERSION FORK PINNED_TAG COMPILE_LIBRARY) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN} ) + + set(RAFT_COMPONENTS "") + if(PKG_COMPILE_LIBRARY) + string(APPEND RAFT_COMPONENTS " compiled") + endif() + # Invoke CPM find_package() + # (From rapids-cpm) + rapids_cpm_find(raft ${PKG_VERSION} + GLOBAL_TARGETS raft::raft + BUILD_EXPORT_SET VectorSimilarity-exports + INSTALL_EXPORT_SET VectorSimilarity-exports + COMPONENTS ${RAFT_COMPONENTS} + CPM_ARGS + GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git + GIT_TAG ${PKG_PINNED_TAG} + SOURCE_SUBDIR cpp + OPTIONS + "BUILD_TESTS OFF" + "BUILD_BENCH OFF" + "RAFT_COMPILE_LIBRARY ${PKG_COMPILE_LIBRARY}" + ) + if(raft_ADDED) + message(VERBOSE "VectorSimilarity: Using RAFT located in ${raft_SOURCE_DIR}") + else() + message(VERBOSE "VectorSimilarity: Using RAFT located in ${raft_DIR}") + endif() + endfunction() + + # Change pinned tag here to test a commit in CI + # To use a different RAFT locally, set the CMake variable + # CPM_raft_SOURCE=/path/to/local/raft + find_and_configure_raft(VERSION ${RAFT_VERSION}.00 + FORK ${RAFT_FORK} + PINNED_TAG ${RAFT_PINNED_TAG} + COMPILE_LIBRARY OFF + ) +endif() diff --git a/src/VecSim/CMakeLists.txt b/src/VecSim/CMakeLists.txt index e55f4d703..46d89cbc3 100644 --- a/src/VecSim/CMakeLists.txt +++ b/src/VecSim/CMakeLists.txt @@ -19,6 +19,8 @@ add_library(VectorSimilarity ${VECSIM_LIBTYPE} index_factories/hnsw_factory.cpp index_factories/tiered_factory.cpp index_factories/index_factory.cpp + $<$:index_factories/raft_ivf_factory.cu> + $<$:index_factories/raft_ivf_tiered_factory.cpp> algorithms/hnsw/visited_nodes_handler.cpp vec_sim.cpp vec_sim_interface.cpp @@ -31,9 +33,19 @@ add_library(VectorSimilarity ${VECSIM_LIBTYPE} ${HEADER_LIST} ) -target_link_libraries(VectorSimilarity VectorSimilaritySpaces) if(VECSIM_BUILD_TESTS) add_library(VectorSimilaritySerializer utils/serializer.cpp) - target_link_libraries(VectorSimilarity VectorSimilaritySerializer) endif() + +target_link_libraries(VectorSimilarity +PUBLIC + VectorSimilaritySpaces + $<$:VectorSimilaritySerializer> +PRIVATE + $<$:raft::raft> + $<$:CUDA::cusolver> + $<$:CUDA::cublas> + $<$:CUDA::curand> + $<$:CUDA::cusparse> +) diff --git a/src/VecSim/algorithms/brute_force/brute_force.h b/src/VecSim/algorithms/brute_force/brute_force.h index 43399bacd..3e223320d 100644 --- a/src/VecSim/algorithms/brute_force/brute_force.h +++ b/src/VecSim/algorithms/brute_force/brute_force.h @@ -35,6 +35,7 @@ class BruteForceIndex : public VecSimIndexAbstract { public: BruteForceIndex(const BFParams *params, const AbstractIndexInitParams &abstractInitParams); + virtual void clear() = 0; size_t indexSize() const override; size_t indexCapacity() const override; vecsim_stl::vector computeBlockScores(const DataBlock &block, const void *queryBlob, @@ -54,6 +55,7 @@ class BruteForceIndex : public VecSimIndexAbstract { VecSimQueryParams *queryParams) const override; bool preferAdHocSearch(size_t subsetSize, size_t k, bool initial_check) const override; inline labelType getVectorLabel(idType id) const { return idToLabelMapping.at(id); } + inline vecsim_stl::vector getLabels() const { return idToLabelMapping; } inline const vecsim_stl::vector &getVectorBlocks() const { return vectorBlocks; } inline const labelType getLabelByInternalId(idType internal_id) const { diff --git a/src/VecSim/algorithms/brute_force/brute_force_multi.h b/src/VecSim/algorithms/brute_force/brute_force_multi.h index 086adc13e..6933c4c20 100644 --- a/src/VecSim/algorithms/brute_force/brute_force_multi.h +++ b/src/VecSim/algorithms/brute_force/brute_force_multi.h @@ -23,6 +23,14 @@ class BruteForceIndex_Multi : public BruteForceIndex { ~BruteForceIndex_Multi() {} + void clear() override { + this->labelToIdsLookup.clear(); + this->idToLabelMapping.clear(); + this->idToLabelMapping.shrink_to_fit(); + this->vectorBlocks.clear(); + this->vectorBlocks.shrink_to_fit(); + this->count = idType{}; + } int addVector(const void *vector_data, labelType label, void *auxiliaryCtx = nullptr) override; int deleteVector(labelType labelType) override; int deleteVectorById(labelType label, idType id) override; diff --git a/src/VecSim/algorithms/brute_force/brute_force_single.h b/src/VecSim/algorithms/brute_force/brute_force_single.h index dba740c89..ea0adc3ff 100644 --- a/src/VecSim/algorithms/brute_force/brute_force_single.h +++ b/src/VecSim/algorithms/brute_force/brute_force_single.h @@ -21,6 +21,14 @@ class BruteForceIndex_Single : public BruteForceIndex { const AbstractIndexInitParams &abstractInitParams); ~BruteForceIndex_Single(); + void clear() override { + this->labelToIdLookup.clear(); + this->idToLabelMapping.clear(); + this->idToLabelMapping.shrink_to_fit(); + this->vectorBlocks.clear(); + this->vectorBlocks.shrink_to_fit(); + this->count = idType{}; + } int addVector(const void *vector_data, labelType label, void *auxiliaryCtx = nullptr) override; int deleteVector(labelType label) override; int deleteVectorById(labelType label, idType id) override; diff --git a/src/VecSim/algorithms/raft_ivf/ivf.cuh b/src/VecSim/algorithms/raft_ivf/ivf.cuh new file mode 100644 index 000000000..2c36ffc56 --- /dev/null +++ b/src/VecSim/algorithms/raft_ivf/ivf.cuh @@ -0,0 +1,389 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include "VecSim/vec_sim.h" +// For VecSimMetric, RaftIvfParams, labelType +#include "VecSim/vec_sim_common.h" +// For VecSimIndexAbstract +#include "VecSim/vec_sim_index.h" +#include "VecSim/query_result_definitions.h" // VecSimQueryResult VecSimQueryReply +#include "VecSim/algorithms/raft_ivf/ivf_interface.h" // RaftIvfInterface +#include "VecSim/memory/vecsim_malloc.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +inline auto constexpr GetRaftDistanceType(VecSimMetric vsm) { + auto result = raft::distance::DistanceType{}; + switch (vsm) { + case VecSimMetric_L2: + result = raft::distance::DistanceType::L2Expanded; + break; + case VecSimMetric_IP: + case VecSimMetric_Cosine: + result = raft::distance::DistanceType::InnerProduct; + break; + default: + throw raft::exception("Metric not supported"); + } + return result; +} + +inline auto constexpr GetRaftCodebookKind(RaftIVFPQCodebookKind vss_codebook) { + auto result = raft::neighbors::ivf_pq::codebook_gen{}; + switch (vss_codebook) { + case RaftIVFPQCodebookKind_PerCluster: + result = raft::neighbors::ivf_pq::codebook_gen::PER_CLUSTER; + break; + case RaftIVFPQCodebookKind_PerSubspace: + result = raft::neighbors::ivf_pq::codebook_gen::PER_SUBSPACE; + break; + default: + throw raft::exception("Unexpected IVFPQ codebook kind"); + } + return result; +} + +inline auto constexpr GetCudaType(CudaType vss_type) { + auto result = cudaDataType_t{}; + switch (vss_type) { + case CUDAType_R_32F: + result = CUDA_R_32F; + break; + case CUDAType_R_16F: + result = CUDA_R_16F; + break; + case CUDAType_R_8U: + result = CUDA_R_8U; + break; + default: + throw raft::exception("Unexpected CUDA type"); + } + return result; +} + +void init_raft_resources() { + auto static init_flag = std::once_flag{}; + std::call_once(init_flag, []() { + raft::device_resources_manager::set_streams_per_device(8); // TODO: use env variable + raft::device_resources_manager::set_stream_pools_per_device(8); + // Create a memory pool with half of the available GPU memory. + raft::device_resources_manager::set_mem_pool(); + }); +} + +template +struct RaftIvfIndex : public RaftIvfInterface { + using data_type = DataType; + using dist_type = DistType; + +private: + // Allow either IVF-Flat or IVF-PQ parameters + using build_params_t = std::variant; + using search_params_t = std::variant; + using internal_idx_t = std::uint64_t; + using index_flat_t = raft::neighbors::ivf_flat::index; + using index_pq_t = raft::neighbors::ivf_pq::index; + using ann_index_t = std::variant; + +public: + RaftIvfIndex(const RaftIvfParams *raftIvfParams, const AbstractIndexInitParams &commonParams) + : RaftIvfInterface{commonParams}, + build_params_{raftIvfParams->usePQ ? build_params_t{std::in_place_index<1>} + : build_params_t{std::in_place_index<0>}}, + search_params_{raftIvfParams->usePQ ? search_params_t{std::in_place_index<1>} + : search_params_t{std::in_place_index<0>}}, + index_{std::nullopt}, deleted_indices_{std::nullopt}, numDeleted_{0}, + idToLabelLookup_{this->allocator}, labelToIdLookup_{this->allocator} { + std::visit( + [raftIvfParams](auto &&inner) { + inner.metric = GetRaftDistanceType(raftIvfParams->metric); + inner.n_lists = raftIvfParams->nLists; + inner.kmeans_n_iters = raftIvfParams->kmeans_nIters; + inner.add_data_on_build = false; + inner.kmeans_trainset_fraction = raftIvfParams->kmeans_trainsetFraction; + inner.conservative_memory_allocation = raftIvfParams->conservativeMemoryAllocation; + if constexpr (std::is_same_v) { + inner.adaptive_centers = raftIvfParams->adaptiveCenters; + } else if constexpr (std::is_same_v) { + inner.pq_bits = raftIvfParams->pqBits; + inner.pq_dim = raftIvfParams->pqDim; + inner.codebook_kind = GetRaftCodebookKind(raftIvfParams->codebookKind); + } + }, + build_params_); + std::visit( + [raftIvfParams](auto &&inner) { + inner.n_probes = raftIvfParams->nProbes; + if constexpr (std::is_same_v) { + inner.lut_dtype = GetCudaType(raftIvfParams->lutType); + inner.internal_distance_dtype = + GetCudaType(raftIvfParams->internalDistanceType); + inner.preferred_shmem_carvout = raftIvfParams->preferredShmemCarveout; + } + }, + search_params_); + + cosine_postprocess_ = raftIvfParams->metric == VecSimMetric_Cosine || raftIvfParams->metric == VecSimMetric_IP; + } + int addVector(const void *vector_data, labelType label, void *auxiliaryCtx = nullptr) override { + return addVectorBatch(vector_data, &label, 1, auxiliaryCtx); + } + int addVectorBatch(const void *vector_data, labelType *label, size_t batch_size, + void *auxiliaryCtx = nullptr) override { + const auto &res = raft::device_resources_manager::get_device_resources(); + // Allocate memory on device to hold vectors to be added + auto vector_data_gpu = + raft::make_device_matrix(res, batch_size, this->dim); + + // Copy vector data to previously allocated device buffer + raft::copy(vector_data_gpu.data_handle(), static_cast(vector_data), + this->dim * batch_size, res.get_stream()); + + // Create GPU vector to hold ids + internal_idx_t first_id = this->indexSize(); + internal_idx_t last_id = first_id + batch_size; + auto ids = raft::make_device_vector(res, batch_size); + raft::linalg::range(ids.data_handle(), first_id, last_id, res.get_stream()); + + // Build index if it does not exist, and extend it with the new vectors and their ids + if (std::holds_alternative(build_params_)) { + if (!index_) { + index_ = raft::neighbors::ivf_flat::build( + res, std::get(build_params_), + raft::make_const_mdspan(vector_data_gpu.view())); + deleted_indices_ = {raft::core::bitset(res, 0)}; + } + raft::neighbors::ivf_flat::extend(res, raft::make_const_mdspan(vector_data_gpu.view()), + {raft::make_const_mdspan(ids.view())}, + &std::get(*index_)); + } else { + if (!index_) { + index_ = raft::neighbors::ivf_pq::build( + res, std::get(build_params_), + raft::make_const_mdspan(vector_data_gpu.view())); + deleted_indices_ = {raft::core::bitset(res, 0)}; + } + raft::neighbors::ivf_pq::extend(res, raft::make_const_mdspan(vector_data_gpu.view()), + {raft::make_const_mdspan(ids.view())}, + &std::get(*index_)); + } + + // Add labels to internal idToLabelLookup_ mapping + this->idToLabelLookup_.insert(this->idToLabelLookup_.end(), label, label + batch_size); + for (auto i = 0; i < batch_size; ++i) { + this->labelToIdLookup_[label[i]] = first_id + i; + } + + // Update the size of the deleted indices bitset + deleted_indices_->resize(res, deleted_indices_->size() + batch_size); + + // Ensure that above operation has executed on device before + // returning from this function on host + res.sync_stream(); + return batch_size; + } + int deleteVector(labelType label) override { + // Check if label exists in internal labelToIdLookup_ mapping + auto search = labelToIdLookup_.find(label); + if (search == labelToIdLookup_.end()) { + return 0; + } + const auto &res = raft::device_resources_manager::get_device_resources(); + // Create GPU vector to hold ids to mark as deleted + internal_idx_t id = search->second; + auto id_gpu = raft::make_device_vector(res, 1); + raft::copy(id_gpu.data_handle(), &id, 1, res.get_stream()); + // Mark the id as deleted + deleted_indices_->set(res, raft::make_const_mdspan(id_gpu.view()), false); + + // Remove label from internal labelToIdLookup_ mapping + labelToIdLookup_.erase(search); + // Ensure that above operation has executed on device before + // returning from this function on host + res.sync_stream(); + this->numDeleted_ += 1; + return 1; + } + double getDistanceFrom_Unsafe(labelType label, const void *vector_data) const override { + assert(!"getDistanceFrom not implemented"); + return INVALID_SCORE; + } + size_t indexCapacity() const override { + assert(!"indexCapacity not implemented"); + return 0; + } + inline vecsim_stl::set getLabelsSet() const override { + vecsim_stl::set result(this->allocator); + for (auto const &pair : labelToIdLookup_) { + result.insert(pair.first); + } + return result; + } + // void increaseCapacity() override { assert(!"increaseCapacity not implemented"); } + inline size_t indexLabelCount() const override { return this->labelToIdLookup_.size(); } + VecSimQueryReply *topKQuery(const void *queryBlob, size_t k, + VecSimQueryParams *queryParams) const override { + const auto &res = raft::device_resources_manager::get_device_resources(); + auto result_list = new VecSimQueryReply(this->allocator); + auto nVectors = this->indexSize(); + if (nVectors == 0 || k == 0 || !index_.has_value()) { + return result_list; + } + // Ensure we are not trying to retrieve more vectors than exist in the + // index + k = std::min(k, nVectors); + // Allocate memory on device for search vector + auto vector_data_gpu = + raft::make_device_matrix(res, 1, this->dim); + // Allocate memory on device for neighbor and distance results + auto neighbors_gpu = raft::make_device_matrix(res, 1, k); + auto distances_gpu = raft::make_device_matrix(res, 1, k); + // Copy query vector to device + raft::copy(vector_data_gpu.data_handle(), static_cast(queryBlob), + this->dim, res.get_stream()); + auto bitset_filter = raft::neighbors::filtering::bitset_filter(deleted_indices_->view()); + + // Perform correct search based on index type + if (std::holds_alternative(*index_)) { + raft::neighbors::ivf_flat::search_with_filtering( + res, std::get(search_params_), + std::get(*index_), raft::make_const_mdspan(vector_data_gpu.view()), + neighbors_gpu.view(), distances_gpu.view(), bitset_filter); + } else { + raft::neighbors::ivf_pq::search_with_filtering( + res, std::get(search_params_), + std::get(*index_), raft::make_const_mdspan(vector_data_gpu.view()), + neighbors_gpu.view(), distances_gpu.view(), bitset_filter); + } + + // Allocate host buffers to hold returned results + auto neighbors = vecsim_stl::vector(k, this->allocator); + auto distances = vecsim_stl::vector(k, this->allocator); + // Copy data back from device to host + raft::copy(neighbors.data(), neighbors_gpu.data_handle(), k, res.get_stream()); + raft::copy(distances.data(), distances_gpu.data_handle(), k, res.get_stream()); + + result_list->results.resize(k); + // Ensure search is complete and data have been copied back before + // building query result objects on host + res.sync_stream(); + + for (auto i = 0; i < k; ++i) { + result_list->results[i].id = idToLabelLookup_[neighbors[i]]; + if (cosine_postprocess_) { + result_list->results[i].score = 1.0f - distances[i]; + } else { + result_list->results[i].score = distances[i]; + } + } + + return result_list; + } + + virtual VecSimQueryReply *rangeQuery(const void *queryBlob, double radius, + VecSimQueryParams *queryParams) const override { + assert(!"RangeQuery not implemented"); + return nullptr; + } + VecSimInfoIterator *infoIterator() const override { + assert(!"infoIterator not implemented"); + return nullptr; + } + virtual VecSimBatchIterator *newBatchIterator(const void *queryBlob, + VecSimQueryParams *queryParams) const override { + assert(!"newBatchIterator not implemented"); + return nullptr; + } + bool preferAdHocSearch(size_t subsetSize, size_t k, bool initial_check) const override { + assert(!"preferAdHocSearch not implemented"); + return false; + } + + virtual uint32_t nLists() const override { + return std::visit([](auto &¶ms) { return params.n_lists; }, build_params_); + } + + size_t indexSize() const override { + auto result = size_t{}; + if (index_) { + if (std::holds_alternative(*index_)) { + result = std::get(*index_).size(); + } else { + result = std::get(*index_).size(); + } + } + return result - this->numDeleted_; + } + VecSimIndexBasicInfo basicInfo() const override { + VecSimIndexBasicInfo info = this->getBasicInfo(); + if (std::holds_alternative(build_params_)) { + info.algo = VecSimAlgo_RAFT_IVFFLAT; + } else { + info.algo = VecSimAlgo_RAFT_IVFPQ; + } + info.isTiered = false; + return info; + } + VecSimIndexInfo info() const override { + VecSimIndexInfo info; + info.commonInfo = this->getCommonInfo(); + info.raftIvfInfo.nLists = nLists(); + if (std::holds_alternative(build_params_)) { + info.commonInfo.basicInfo.algo = VecSimAlgo_RAFT_IVFPQ; + const auto build_params_pq = + std::get(build_params_); + info.raftIvfInfo.pqBits = build_params_pq.pq_bits; + info.raftIvfInfo.pqDim = build_params_pq.pq_dim; + } else { + info.commonInfo.basicInfo.algo = VecSimAlgo_RAFT_IVFFLAT; + } + return info; + } + + virtual inline void setNProbes(uint32_t n_probes) override { + std::visit([n_probes](auto &¶ms) { params.n_probes = n_probes; }, search_params_); + } + +private: + // Store build params to allow for index build on first batch + // insertion + build_params_t build_params_; + // Store search params to use with each search after initializing in + // constructor + search_params_t search_params_; + // Use a std::optional to allow building of the index on first batch + // insertion + std::optional index_; + // Bitset used for deleteVectors and search filtering. + std::optional> deleted_indices_; + internal_idx_t numDeleted_ = 0; + + vecsim_stl::vector idToLabelLookup_; + vecsim_stl::unordered_map labelToIdLookup_; + + bool cosine_postprocess_ = false; +}; diff --git a/src/VecSim/algorithms/raft_ivf/ivf_interface.h b/src/VecSim/algorithms/raft_ivf/ivf_interface.h new file mode 100644 index 000000000..35c2faa34 --- /dev/null +++ b/src/VecSim/algorithms/raft_ivf/ivf_interface.h @@ -0,0 +1,21 @@ +#pragma once + +#include "VecSim/vec_sim.h" +// For VecSimIndexAbstract +#include "VecSim/vec_sim_index.h" +// For labelType +#include "VecSim/vec_sim_common.h" + +// Non-CUDA Interface of the RaftIVF index to avoid importing CUDA code +// in the tiered index. +template +struct RaftIvfInterface : public VecSimIndexAbstract { + RaftIvfInterface(const AbstractIndexInitParams ¶ms) + : VecSimIndexAbstract(params) {} + virtual uint32_t nLists() const = 0; + virtual inline void setNProbes(uint32_t n_probes) = 0; + + virtual int addVectorBatch(const void *vector_data, labelType *label, size_t batch_size, + void *auxiliaryCtx = nullptr) = 0; + virtual vecsim_stl::set getLabelsSet() const = 0; +}; diff --git a/src/VecSim/algorithms/raft_ivf/ivf_tiered.h b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h new file mode 100644 index 000000000..3bc5b241e --- /dev/null +++ b/src/VecSim/algorithms/raft_ivf/ivf_tiered.h @@ -0,0 +1,247 @@ +#pragma once + +#include +#include +#include "VecSim/algorithms/raft_ivf/ivf_interface.h" +#include "VecSim/vec_sim_tiered_index.h" + +struct RAFTTransferJob : public AsyncJob { + bool force_ = false; + RAFTTransferJob(std::shared_ptr allocator, JobCallback insertCb, + VecSimIndex *index_, bool force = false) + : AsyncJob{allocator, RAFT_TRANSFER_JOB, insertCb, index_}, force_{force} {} +}; + +template +struct TieredRaftIvfIndex : public VecSimTieredIndex { + TieredRaftIvfIndex(RaftIvfInterface *raftIvfIndex, + BruteForceIndex *bf_index, + const TieredIndexParams &tieredParams, + std::shared_ptr allocator) + : VecSimTieredIndex(raftIvfIndex, bf_index, tieredParams, allocator) { + assert( + raftIvfIndex->nLists() < this->flatBufferLimit && + "The flat buffer limit must be greater than the number of lists in the backend index"); + this->minVectorsInit = + std::max((size_t)1, tieredParams.specificParams.tieredRaftIvfParams.minVectorsInit); + } + ~TieredRaftIvfIndex() { + // Delete all the pending jobs + } + + int addVector(const void *blob, labelType label, void *auxiliaryCtx) override { + int ret = 1; + // If the flat index is full, write to the backend index + if (this->frontendIndex->indexSize() >= this->flatBufferLimit) { + // If the backend index is empty, build it with all the vectors + // Otherwise, just add the vector to the backend index + auto temp_job = RAFTTransferJob(this->allocator, executeTransferJobWrapper, this, true); + executeTransferJob(&temp_job); + } + + // If the backend index is already built and that the write mode is in place + // add the vector to the backend index + if (this->backendIndex->indexSize() > 0 && this->getWriteMode() == VecSim_WriteInPlace) { + this->mainIndexGuard.lock(); + ret = this->backendIndex->addVector(blob, label); + this->mainIndexGuard.unlock(); + return ret; + } + + // Otherwise, add the vector to the flat index + this->flatIndexGuard.lock(); + ret = this->frontendIndex->addVector(blob, label); + this->flatIndexGuard.unlock(); + + // Submit a transfer job + AsyncJob *new_insert_job = + new (this->allocator) RAFTTransferJob(this->allocator, executeTransferJobWrapper, this); + this->submitSingleJob(new_insert_job); + + return ret; + } + + int deleteVector(labelType label) override { + int num_deleted_vectors = 0; + this->flatIndexGuard.lock_shared(); + if (this->frontendIndex->isLabelExists(label)) { + this->flatIndexGuard.unlock_shared(); + this->flatIndexGuard.lock(); + // Check again if the label exists, as it may have been removed while we released the + // lock. + if (this->frontendIndex->isLabelExists(label)) { + // Remove every id that corresponds the label from the flat buffer. + auto updated_ids = this->frontendIndex->deleteVectorAndGetUpdatedIds(label); + num_deleted_vectors += updated_ids.size(); + } + this->flatIndexGuard.unlock(); + } else { + this->flatIndexGuard.unlock_shared(); + } + + // delete in place. TODO: Add async job for this + this->mainIndexGuard.lock(); + num_deleted_vectors += this->backendIndex->deleteVector(label); + this->mainIndexGuard.unlock(); + return num_deleted_vectors; + } + + size_t indexSize() const override { + this->flatIndexGuard.lock_shared(); + this->mainIndexGuard.lock_shared(); + size_t result = (this->backendIndex->indexSize() + this->frontendIndex->indexSize()); + this->flatIndexGuard.unlock_shared(); + this->mainIndexGuard.unlock_shared(); + return result; + } + + size_t indexLabelCount() const override { + this->flatIndexGuard.lock_shared(); + this->mainIndexGuard.lock_shared(); + auto flat_labels = this->frontendIndex->getLabelsSet(); + auto raft_ivf_labels = this->getBackendIndex()->getLabelsSet(); + this->flatIndexGuard.unlock_shared(); + this->mainIndexGuard.unlock_shared(); + std::vector output; + std::set_union(flat_labels.begin(), flat_labels.end(), raft_ivf_labels.begin(), + raft_ivf_labels.end(), std::back_inserter(output)); + return output.size(); + } + + size_t indexCapacity() const override { + return (this->backendIndex->indexCapacity() + this->frontendIndex->indexCapacity()); + } + + double getDistanceFrom_Unsafe(labelType label, const void *blob) const override { + auto flat_dist = this->frontendIndex->getDistanceFrom_Unsafe(label, blob); + auto raft_dist = this->backendIndex->getDistanceFrom_Unsafe(label, blob); + return std::fmin(flat_dist, raft_dist); + } + + static void executeTransferJobWrapper(AsyncJob *job) { + if (job->isValid) { + auto *transfer_job = reinterpret_cast(job); + auto *job_index = + reinterpret_cast *>(transfer_job->index); + job_index->executeTransferJob(transfer_job); + } + delete job; + } + + VecSimIndexBasicInfo basicInfo() const override { + VecSimIndexBasicInfo info = this->backendIndex->getBasicInfo(); + info.isTiered = true; + return info; + } + + VecSimBatchIterator *newBatchIterator(const void *queryBlob, + VecSimQueryParams *queryParams) const override { + assert(!"newBatchIterator not implemented"); + return nullptr; + } + + inline void setLastSearchMode(VecSearchMode mode) override {} + + void runGC() override {} + + void acquireSharedLocks() override { + this->flatIndexGuard.lock_shared(); + this->mainIndexGuard.lock_shared(); + } + + void releaseSharedLocks() override { + this->flatIndexGuard.unlock_shared(); + this->mainIndexGuard.unlock_shared(); + } + + inline void setNProbes(uint32_t n_probes) { + this->mainIndexGuard.lock(); + this->getBackendIndex()->setNProbes(n_probes); + this->mainIndexGuard.unlock(); + } + +private: + size_t minVectorsInit = 1; + + // This ptr is designating the latest transfer job. It is protected by flat buffer lock + + inline auto *getBackendIndex() const { + return dynamic_cast *>(this->backendIndex); + } + + void executeTransferJob(RAFTTransferJob *job) { + size_t nVectors = this->frontendIndex->indexSize(); + // No vectors to transfer + if (nVectors == 0) { + return; + } + + // Don't transfer less than nLists * minVectorsInit vectors if the backend index is empty + // (for kmeans initialization purposes) + if (!job->force_) { + auto main_nVectors = this->backendIndex->indexSize(); + size_t min_nVectors = 1; + if (main_nVectors == 0) + min_nVectors = this->minVectorsInit * getBackendIndex()->nLists(); + + if (nVectors < min_nVectors) { + return; + } + } + + this->flatIndexGuard.lock(); + // Check that the job has not been cancelled while waiting for the lock + if (!job->isValid) { + this->flatIndexGuard.unlock(); + return; + } + // Check that there are still vectors to transfer after exclusive lock + nVectors = this->frontendIndex->indexSize(); + if (nVectors == 0) { + this->flatIndexGuard.unlock(); + return; + } + + auto dim = this->backendIndex->getDim(); + const auto &vectorBlocks = this->frontendIndex->getVectorBlocks(); + auto *vectorData = (DataType *)this->allocator->allocate(nVectors * dim * sizeof(DataType)); + auto *labelData = (labelType *)this->allocator->allocate(nVectors * sizeof(labelType)); + + // Transfer vectors to a contiguous host buffer + auto *curr_ptr = vectorData; + for (std::uint32_t block_id = 0; block_id < vectorBlocks.size(); ++block_id) { + const auto *in_begin = + reinterpret_cast(vectorBlocks[block_id].getElement(0)); + auto length = vectorBlocks[block_id].getLength(); + std::copy_n(in_begin, length * dim, curr_ptr); + curr_ptr += length * dim; + } + + std::copy_n(this->frontendIndex->getLabels().data(), nVectors, labelData); + this->frontendIndex->clear(); + + // Lock the main index before unlocking the front index so that both indexes are not empty + // at the same time + this->mainIndexGuard.lock(); + this->flatIndexGuard.unlock(); + + // Add the vectors to the backend index + getBackendIndex()->addVectorBatch(vectorData, labelData, nVectors); + this->mainIndexGuard.unlock(); + this->allocator->free_allocation(vectorData); + this->allocator->free_allocation(labelData); + } + +#ifdef BUILD_TESTS + INDEX_TEST_FRIEND_CLASS(BM_VecSimBasics) + INDEX_TEST_FRIEND_CLASS(BM_VecSimCommon) + INDEX_TEST_FRIEND_CLASS(BM_VecSimIndex); + INDEX_TEST_FRIEND_CLASS(RaftIvfTieredTest) + INDEX_TEST_FRIEND_CLASS(RaftIvfTieredTest_transferJob_Test) + INDEX_TEST_FRIEND_CLASS(RaftIvfTieredTest_transferJobAsync_Test) + INDEX_TEST_FRIEND_CLASS(RaftIvfTieredTest_transferJob_inplace_Test) + INDEX_TEST_FRIEND_CLASS(RaftIvfTieredTest_deleteVector_backend_Test) + INDEX_TEST_FRIEND_CLASS(RaftIvfTieredTest_searchMetricCosine_Test) + INDEX_TEST_FRIEND_CLASS(RaftIvfTieredTest_searchMetricIP_Test) +#endif +}; diff --git a/src/VecSim/index_factories/index_factory.cpp b/src/VecSim/index_factories/index_factory.cpp index e71affa4f..60365b45c 100644 --- a/src/VecSim/index_factories/index_factory.cpp +++ b/src/VecSim/index_factories/index_factory.cpp @@ -8,6 +8,9 @@ #include "hnsw_factory.h" #include "brute_force_factory.h" #include "tiered_factory.h" +#ifdef USE_CUDA +#include "raft_ivf_factory.h" +#endif #include "VecSim/vec_sim_index.h" namespace VecSimFactory { @@ -25,6 +28,16 @@ VecSimIndex *NewIndex(const VecSimParams *params) { index = BruteForceFactory::NewIndex(params); break; } + case VecSimAlgo_RAFT_IVFFLAT: + case VecSimAlgo_RAFT_IVFPQ: { +#ifdef USE_CUDA + index = RaftIvfFactory::NewIndex(¶ms->algoParams.raftIvfParams); +#else + throw std::runtime_error( + "RAFT_IVFFLAT and RAFT_IVFPQ are not supported in CPU version"); +#endif + break; + } case VecSimAlgo_TIERED: { index = TieredFactory::NewIndex(¶ms->algoParams.tieredParams); break; @@ -42,6 +55,13 @@ size_t EstimateInitialSize(const VecSimParams *params) { return HNSWFactory::EstimateInitialSize(¶ms->algoParams.hnswParams); case VecSimAlgo_BF: return BruteForceFactory::EstimateInitialSize(¶ms->algoParams.bfParams); + case VecSimAlgo_RAFT_IVFFLAT: + case VecSimAlgo_RAFT_IVFPQ: +#ifdef USE_CUDA + return RaftIvfFactory::EstimateInitialSize(¶ms->algoParams.raftIvfParams); +#else + throw std::runtime_error("RAFT_IVFFLAT and RAFT_IVFPQ are not supported in CPU version"); +#endif case VecSimAlgo_TIERED: return TieredFactory::EstimateInitialSize(¶ms->algoParams.tieredParams); } @@ -54,6 +74,13 @@ size_t EstimateElementSize(const VecSimParams *params) { return HNSWFactory::EstimateElementSize(¶ms->algoParams.hnswParams); case VecSimAlgo_BF: return BruteForceFactory::EstimateElementSize(¶ms->algoParams.bfParams); + case VecSimAlgo_RAFT_IVFFLAT: + case VecSimAlgo_RAFT_IVFPQ: +#ifdef USE_CUDA + return RaftIvfFactory::EstimateElementSize(¶ms->algoParams.raftIvfParams); +#else + throw std::runtime_error("RAFT_IVFFLAT and RAFT_IVFPQ are not supported in CPU version"); +#endif case VecSimAlgo_TIERED: return TieredFactory::EstimateElementSize(¶ms->algoParams.tieredParams); } diff --git a/src/VecSim/index_factories/raft_ivf_factory.cu b/src/VecSim/index_factories/raft_ivf_factory.cu new file mode 100644 index 000000000..e499cea89 --- /dev/null +++ b/src/VecSim/index_factories/raft_ivf_factory.cu @@ -0,0 +1,83 @@ +#include "VecSim/index_factories/brute_force_factory.h" +#include "VecSim/algorithms/raft_ivf/ivf.cuh" + +namespace RaftIvfFactory { + +static AbstractIndexInitParams NewAbstractInitParams(const VecSimParams *params) { + + const RaftIvfParams *raftIvfParams = ¶ms->algoParams.raftIvfParams; + AbstractIndexInitParams abstractInitParams = { + .allocator = VecSimAllocator::newVecsimAllocator(), + .dim = raftIvfParams->dim, + .vecType = raftIvfParams->type, + .metric = raftIvfParams->metric, + //.multi = raftIvfParams->multi, + //.logCtx = params->logCtx + }; + return abstractInitParams; +} + +VecSimIndex *NewIndex(const RaftIvfParams *raftIvfParams, + const AbstractIndexInitParams &abstractInitParams) { + assert(raftIvfParams->type == VecSimType_FLOAT32 && "Invalid IVF data type algorithm"); + if (raftIvfParams->type == VecSimType_FLOAT32) { + return new (abstractInitParams.allocator) + RaftIvfIndex(raftIvfParams, abstractInitParams); + } + + // If we got here something is wrong. + return NULL; +} + +VecSimIndex *NewIndex(const VecSimParams *params) { + const RaftIvfParams *raftIvfParams = ¶ms->algoParams.raftIvfParams; + AbstractIndexInitParams abstractInitParams = NewAbstractInitParams(params); + return NewIndex(raftIvfParams, NewAbstractInitParams(params)); +} + +VecSimIndex *NewIndex(const RaftIvfParams *raftIvfParams) { + VecSimParams params = {.algoParams{.raftIvfParams = RaftIvfParams{*raftIvfParams}}}; + return NewIndex(¶ms); +} + +size_t EstimateInitialSize(const RaftIvfParams *raftIvfParams) { + size_t allocations_overhead = VecSimAllocator::getAllocationOverheadSize(); + + // Constant part (not effected by parameters). + size_t est = sizeof(VecSimAllocator) + allocations_overhead; + est += sizeof(RaftIvfIndex); // Object size + if (!raftIvfParams->usePQ) { + // Size of each cluster data + est += raftIvfParams->nLists * + sizeof(raft::neighbors::ivf_flat::list_data); + // Vector of shared ptr to cluster + est += raftIvfParams->nLists * + sizeof(std::shared_ptr>); + } else { + // Size of each cluster data + est += raftIvfParams->nLists * sizeof(raft::neighbors::ivf_pq::list_data); + // accum_sorted_sizes_ Array + est += raftIvfParams->nLists * sizeof(std::int64_t); + // vector of shared ptr to cluster + est += raftIvfParams->nLists * + sizeof(std::shared_ptr>); + } + return est; +} + +size_t EstimateElementSize(const RaftIvfParams *raftIvfParams) { + // Those elements are stored only on GPU. + size_t est = 0; + if (!raftIvfParams->usePQ) { + // Size of vec + size of label. + est += raftIvfParams->dim * VecSimType_sizeof(raftIvfParams->type) + sizeof(labelType); + } else { + size_t pq_dim = raftIvfParams->pqDim; + if (pq_dim == 0) // Estimation. + pq_dim = raftIvfParams->dim >= 128 ? raftIvfParams->dim / 2 : raftIvfParams->dim; + // Size of vec after compression + size of label + est += raftIvfParams->pqBits * pq_dim + sizeof(labelType); + } + return est; +} +}; // namespace RaftIvfFactory diff --git a/src/VecSim/index_factories/raft_ivf_factory.h b/src/VecSim/index_factories/raft_ivf_factory.h new file mode 100644 index 000000000..040c1c9d9 --- /dev/null +++ b/src/VecSim/index_factories/raft_ivf_factory.h @@ -0,0 +1,18 @@ +#pragma once + +#include // size_t +#include // std::shared_ptr + +#include "VecSim/vec_sim.h" //typedef VecSimIndex +#include "VecSim/vec_sim_common.h" // RaftIvfParams +#include "VecSim/memory/vecsim_malloc.h" // VecSimAllocator +#include "VecSim/vec_sim_index.h" + +namespace RaftIvfFactory { + +VecSimIndex *NewIndex(const VecSimParams *params); +VecSimIndex *NewIndex(const RaftIvfParams *params); +size_t EstimateInitialSize(const RaftIvfParams *params); +size_t EstimateElementSize(const RaftIvfParams *params); + +}; // namespace RaftIvfFactory diff --git a/src/VecSim/index_factories/raft_ivf_tiered_factory.cpp b/src/VecSim/index_factories/raft_ivf_tiered_factory.cpp new file mode 100644 index 000000000..2befa3a13 --- /dev/null +++ b/src/VecSim/index_factories/raft_ivf_tiered_factory.cpp @@ -0,0 +1,69 @@ +#include "VecSim/index_factories/brute_force_factory.h" +#include "VecSim/algorithms/raft_ivf/ivf_tiered.h" +#include "VecSim/algorithms/raft_ivf/ivf_interface.h" +#include "VecSim/index_factories/tiered_factory.h" +#include "VecSim/index_factories/raft_ivf_factory.h" + +namespace TieredRaftIvfFactory { + +VecSimIndex *NewIndex(const TieredIndexParams *params) { + assert(params->primaryIndexParams->algoParams.raftIvfParams.type == VecSimType_FLOAT32 && + "Invalid IVF data type algorithm"); + + using DataType = float; + using DistType = float; + // initialize raft index + auto *raft_index = reinterpret_cast *>( + RaftIvfFactory::NewIndex(params->primaryIndexParams)); + // initialize brute force index + BFParams bf_params = { + .type = params->primaryIndexParams->algoParams.raftIvfParams.type, + .dim = params->primaryIndexParams->algoParams.raftIvfParams.dim, + .metric = params->primaryIndexParams->algoParams.raftIvfParams.metric, + .multi = params->primaryIndexParams->algoParams.raftIvfParams.multi, + //.blockSize = params->primaryIndexParams->algoParams.raftIvfParams.blockSize + }; + + std::shared_ptr flat_allocator = VecSimAllocator::newVecsimAllocator(); + AbstractIndexInitParams abstractInitParams = {.allocator = flat_allocator, + .dim = bf_params.dim, + .vecType = bf_params.type, + .metric = bf_params.metric, + .blockSize = bf_params.blockSize, + .multi = bf_params.multi, + .logCtx = params->primaryIndexParams->logCtx}; + auto frontendIndex = static_cast *>( + BruteForceFactory::NewIndex(&bf_params, abstractInitParams)); + + // Create new tiered RaftIVF index + std::shared_ptr management_layer_allocator = + VecSimAllocator::newVecsimAllocator(); + + return new (management_layer_allocator) TieredRaftIvfIndex( + raft_index, frontendIndex, *params, management_layer_allocator); +} + +// The size estimation is the sum of the buffer (brute force) and main index initial sizes +// estimations, plus the tiered index class size. Note it does not include the size of internal +// containers such as the job queue, as those depend on the user implementation. +size_t EstimateInitialSize(const TieredIndexParams *params) { + auto raft_ivf_params = params->primaryIndexParams->algoParams.raftIvfParams; + + // Add size estimation of VecSimTieredIndex sub indexes. + size_t est = RaftIvfFactory::EstimateInitialSize(&raft_ivf_params); + + // Management layer allocator overhead. + size_t allocations_overhead = VecSimAllocator::getAllocationOverheadSize(); + est += sizeof(VecSimAllocator) + allocations_overhead; + + // Size of the TieredRaftIvfIndex struct. + if (raft_ivf_params.type == VecSimType_FLOAT32) { + est += sizeof(TieredRaftIvfIndex); + } else if (raft_ivf_params.type == VecSimType_FLOAT64) { + est += sizeof(TieredRaftIvfIndex); + } + + return est; +} + +}; // namespace TieredRaftIvfFactory diff --git a/src/VecSim/index_factories/raft_ivf_tiered_factory.h b/src/VecSim/index_factories/raft_ivf_tiered_factory.h new file mode 100644 index 000000000..a89486531 --- /dev/null +++ b/src/VecSim/index_factories/raft_ivf_tiered_factory.h @@ -0,0 +1,17 @@ +#pragma once + +#include "VecSim/vec_sim.h" //typedef VecSimIndex +#include "VecSim/vec_sim_common.h" // RaftIvfParams +#include "VecSim/memory/vecsim_malloc.h" // VecSimAllocator +#include "VecSim/vec_sim_index.h" + +namespace TieredRaftIvfFactory { + +VecSimIndex *NewIndex(const TieredIndexParams *params); + +// The size estimation is the sum of the buffer (brute force) and main index initial sizes +// estimations, plus the tiered index class size. Note it does not include the size of internal +// containers such as the job queue, as those depend on the user implementation. +size_t EstimateInitialSize(const TieredIndexParams *params); + +}; // namespace TieredRaftIvfFactory diff --git a/src/VecSim/index_factories/tiered_factory.cpp b/src/VecSim/index_factories/tiered_factory.cpp index fbf89c481..fbc5f8d05 100644 --- a/src/VecSim/index_factories/tiered_factory.cpp +++ b/src/VecSim/index_factories/tiered_factory.cpp @@ -8,6 +8,11 @@ #include "VecSim/index_factories/hnsw_factory.h" #include "VecSim/index_factories/brute_force_factory.h" +#ifdef USE_CUDA +#include "VecSim/index_factories/raft_ivf_tiered_factory.h" +#include "VecSim/index_factories/raft_ivf_factory.h" +#endif + #include "VecSim/algorithms/hnsw/hnsw_tiered.h" namespace TieredFactory { @@ -89,6 +94,13 @@ VecSimIndex *NewIndex(const TieredIndexParams *params) { } else if (type == VecSimType_FLOAT64) { return TieredHNSWFactory::NewIndex(params); } + } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFFLAT || + params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFPQ) { +#ifdef USE_CUDA + return TieredRaftIvfFactory::NewIndex(params); +#else + throw std::runtime_error("RAFT_IVFFLAT and RAFT_IVFPQ are not supported in CPU version"); +#endif } return nullptr; // Invalid algorithm or type. } @@ -99,6 +111,13 @@ size_t EstimateInitialSize(const TieredIndexParams *params) { BFParams bf_params{}; if (params->primaryIndexParams->algo == VecSimAlgo_HNSWLIB) { est += TieredHNSWFactory::EstimateInitialSize(params, bf_params); + } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFFLAT || + params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFPQ) { +#ifdef USE_CUDA + est += TieredRaftIvfFactory::EstimateInitialSize(params); +#else + throw std::runtime_error("RAFT_IVFFLAT and RAFT_IVFPQ are not supported in CPU version"); +#endif } est += BruteForceFactory::EstimateInitialSize(&bf_params); @@ -109,6 +128,14 @@ size_t EstimateElementSize(const TieredIndexParams *params) { size_t est = 0; if (params->primaryIndexParams->algo == VecSimAlgo_HNSWLIB) { est = HNSWFactory::EstimateElementSize(¶ms->primaryIndexParams->algoParams.hnswParams); + } else if (params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFFLAT || + params->primaryIndexParams->algo == VecSimAlgo_RAFT_IVFPQ) { +#ifdef USE_CUDA + est = RaftIvfFactory::EstimateElementSize( + ¶ms->primaryIndexParams->algoParams.raftIvfParams); +#else + throw std::runtime_error("RAFT_IVFFLAT and RAFT_IVFPQ are not supported in CPU version"); +#endif } return est; } diff --git a/src/VecSim/memory/vecsim_malloc.h b/src/VecSim/memory/vecsim_malloc.h index e25cf6e6b..56f681ac3 100644 --- a/src/VecSim/memory/vecsim_malloc.h +++ b/src/VecSim/memory/vecsim_malloc.h @@ -25,7 +25,7 @@ struct VecSimAllocator { static size_t allocation_header_size; static VecSimMemoryFunctions memFunctions; - VecSimAllocator() : allocated(std::atomic_uint64_t(sizeof(VecSimAllocator))) {} + VecSimAllocator() : allocated((sizeof(VecSimAllocator))) {} public: static std::shared_ptr newVecsimAllocator(); diff --git a/src/VecSim/tombstone_interface.h b/src/VecSim/tombstone_interface.h index 0d864fa7b..ff4c43c0d 100644 --- a/src/VecSim/tombstone_interface.h +++ b/src/VecSim/tombstone_interface.h @@ -1,5 +1,6 @@ #pragma once #include +#include #include "vec_sim_common.h" /* diff --git a/src/VecSim/utils/vec_utils.cpp b/src/VecSim/utils/vec_utils.cpp index b061bddcf..d782210a3 100644 --- a/src/VecSim/utils/vec_utils.cpp +++ b/src/VecSim/utils/vec_utils.cpp @@ -15,6 +15,8 @@ const char *VecSimCommonStrings::ALGORITHM_STRING = "ALGORITHM"; const char *VecSimCommonStrings::FLAT_STRING = "FLAT"; const char *VecSimCommonStrings::HNSW_STRING = "HNSW"; +const char *VecSimCommonStrings::RAFTIVFFLAT_STRING = "RAFT_IVF_FLAT"; +const char *VecSimCommonStrings::RAFTIVFPQ_STRING = "RAFT_IVF_PQ"; const char *VecSimCommonStrings::TIERED_STRING = "TIERED"; const char *VecSimCommonStrings::TYPE_STRING = "TYPE"; @@ -125,6 +127,10 @@ const char *VecSimAlgo_ToString(VecSimAlgo vecsimAlgo) { return VecSimCommonStrings::FLAT_STRING; case VecSimAlgo_HNSWLIB: return VecSimCommonStrings::HNSW_STRING; + case VecSimAlgo_RAFT_IVFFLAT: + return VecSimCommonStrings::RAFTIVFFLAT_STRING; + case VecSimAlgo_RAFT_IVFPQ: + return VecSimCommonStrings::RAFTIVFPQ_STRING; case VecSimAlgo_TIERED: return VecSimCommonStrings::TIERED_STRING; } diff --git a/src/VecSim/utils/vec_utils.h b/src/VecSim/utils/vec_utils.h index 79c8011e7..723a1c2b9 100644 --- a/src/VecSim/utils/vec_utils.h +++ b/src/VecSim/utils/vec_utils.h @@ -18,6 +18,8 @@ struct VecSimCommonStrings { static const char *ALGORITHM_STRING; static const char *FLAT_STRING; static const char *HNSW_STRING; + static const char *RAFTIVFFLAT_STRING; + static const char *RAFTIVFPQ_STRING; static const char *TIERED_STRING; static const char *TYPE_STRING; diff --git a/src/VecSim/vec_sim_common.h b/src/VecSim/vec_sim_common.h index 7eeee6d6e..60d8a23c9 100644 --- a/src/VecSim/vec_sim_common.h +++ b/src/VecSim/vec_sim_common.h @@ -38,11 +38,26 @@ typedef enum { } VecSimType; // Algorithm type/library. -typedef enum { VecSimAlgo_BF, VecSimAlgo_HNSWLIB, VecSimAlgo_TIERED } VecSimAlgo; +typedef enum { + VecSimAlgo_BF, + VecSimAlgo_HNSWLIB, + VecSimAlgo_TIERED, + VecSimAlgo_RAFT_IVFFLAT, + VecSimAlgo_RAFT_IVFPQ +} VecSimAlgo; // Distance metric typedef enum { VecSimMetric_L2, VecSimMetric_IP, VecSimMetric_Cosine } VecSimMetric; +// Codebook kind for IVFPQ indexes +typedef enum { + RaftIVFPQCodebookKind_PerCluster, + RaftIVFPQCodebookKind_PerSubspace +} RaftIVFPQCodebookKind; + +// CUDA types supported by GPU-accelerated indexes +typedef enum { CUDAType_R_32F, CUDAType_R_16F, CUDAType_R_8U } CudaType; + typedef size_t labelType; typedef unsigned int idType; @@ -118,6 +133,11 @@ typedef struct { // all the ready swap jobs in a batch. } TieredHNSWParams; +// A struct that contains Raft IVF tiered index specific params. +typedef struct { + size_t minVectorsInit; // Min. number of vectors per list in Tiered index to init IVF index +} TieredRAFTIVFParams; + // A struct that contains the common tiered index params. typedef struct { void *jobQueue; // External queue that holds the jobs. @@ -128,12 +148,44 @@ typedef struct { VecSimParams *primaryIndexParams; // Parameters to initialize the index. union { TieredHNSWParams tieredHnswParams; + TieredRAFTIVFParams tieredRaftIvfParams; } specificParams; } TieredIndexParams; +typedef struct { + VecSimType type; // Datatype to index. + size_t dim; // Vector's dimension. + VecSimMetric metric; // Distance metric to use in the index. + bool multi; // Determines if the index should multi-index or not. + size_t nLists; // Number of inverted lists + bool conservativeMemoryAllocation; // Use as little GPU memory as possible + size_t kmeans_nIters; // Iterations for kmeans calculation + double kmeans_trainsetFraction; // Fraction of dataset used for kmeans training + unsigned nProbes; // The number of clusters to search + bool usePQ; // If false IVF-Flat is used. If true IVF-PQ is used. + // ***************** IVF-Flat-only parameters ****************** + // The following parameters will be ignored if usePQ is set to true. + bool adaptiveCenters; // If index should be updated for new vectors + + // ******************* IVF-PQ-only parameters ******************* + // The following parameters will be ignored if usePQ is set to false. + size_t pqBits; // Bit length of vector element after PQ compression. + size_t pqDim; // The dimensionality of an encoded vector after PQ + // compression. If set to 0, a heuristic will be used to + // select the dimensionality. + + RaftIVFPQCodebookKind codebookKind; + CudaType lutType; + CudaType internalDistanceType; + double preferredShmemCarveout; // Fraction of GPU's unified memory / L1 + // cache to be used as shared memory + +} RaftIvfParams; + typedef union { HNSWParams hnswParams; BFParams bfParams; + RaftIvfParams raftIvfParams; TieredIndexParams tieredParams; } AlgoParams; @@ -151,6 +203,7 @@ typedef enum { HNSW_REPAIR_NODE_CONNECTIONS_JOB, HNSW_SEARCH_JOB, HNSW_SWAP_JOB, + RAFT_TRANSFER_JOB, INVALID_JOB // to indicate that finding a JobType >= INVALID_JOB is an error } JobType; @@ -233,6 +286,12 @@ typedef struct { char dummy; // For not having this as an empty struct, can be removed after we extend this. } bfInfoStruct; +typedef struct { + size_t nLists; // Number of inverted lists. + size_t pqDim; // Dimensionality of encoded vector after PQ + size_t pqBits; // Bits per encoded vector element after PQ +} raftIvfInfoStruct; + typedef struct HnswTieredInfo { size_t pendingSwapJobsThreshold; } HnswTieredInfo; @@ -242,6 +301,7 @@ typedef struct { // Since we cannot recursively have a struct that contains itself, we need this workaround. union { hnswInfoStruct hnswInfo; + raftIvfInfoStruct raftIvfInfo; } backendInfo; // The backend index info. union { HnswTieredInfo hnswTieredInfo; @@ -265,6 +325,7 @@ typedef struct { union { bfInfoStruct bfInfo; hnswInfoStruct hnswInfo; + raftIvfInfoStruct raftIvfInfo; tieredInfoStruct tieredInfo; }; } VecSimIndexInfo; diff --git a/src/VecSim/vec_sim_tiered_index.h b/src/VecSim/vec_sim_tiered_index.h index bc5f53d71..775521f4d 100644 --- a/src/VecSim/vec_sim_tiered_index.h +++ b/src/VecSim/vec_sim_tiered_index.h @@ -284,6 +284,10 @@ VecSimIndexInfo VecSimTieredIndex::info() const { case VecSimAlgo_HNSWLIB: info.tieredInfo.backendInfo.hnswInfo = backendInfo.hnswInfo; break; + case VecSimAlgo_RAFT_IVFFLAT: + case VecSimAlgo_RAFT_IVFPQ: + info.tieredInfo.backendInfo.raftIvfInfo = backendInfo.raftIvfInfo; + break; case VecSimAlgo_BF: case VecSimAlgo_TIERED: assert(false && "Invalid backend algorithm"); diff --git a/tests/benchmark/CMakeLists.txt b/tests/benchmark/CMakeLists.txt index 8ef952187..5f5256b4f 100644 --- a/tests/benchmark/CMakeLists.txt +++ b/tests/benchmark/CMakeLists.txt @@ -22,7 +22,7 @@ foreach(benchmark IN ITEMS ${BENCHMARKS}) # NOTE: mock_thread_pool.cpp should appear *before* the benchmark files, so we can ensure that the thread pool # globals are initialized before we use them in the benchmark classes (as globals initialization is done by order). add_executable(bm_${benchmark} ../utils/mock_thread_pool.cpp bm_vecsim_general.cpp run_files/bm_${benchmark}.cpp) - target_link_libraries(bm_${benchmark} VectorSimilarity benchmark::benchmark) + target_link_libraries(bm_${benchmark} VectorSimilarity benchmark::benchmark $<$:raft::raft>) endforeach() # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # diff --git a/tests/benchmark/bm_common.h b/tests/benchmark/bm_common.h index 1a3388670..aebb6ae26 100644 --- a/tests/benchmark/bm_common.h +++ b/tests/benchmark/bm_common.h @@ -1,6 +1,9 @@ #pragma once #include "bm_vecsim_index.h" +#ifdef USE_CUDA +#include "VecSim/algorithms/raft_ivf/ivf_tiered.h" +#endif size_t BM_VecSimGeneral::block_size = 1024; @@ -25,11 +28,17 @@ class BM_VecSimCommon : public BM_VecSimIndex { // with respect to the results returned by the flat index. static void TopK_HNSW(benchmark::State &st, unsigned short index_offset = 0); static void TopK_Tiered(benchmark::State &st, unsigned short index_offset = 0); +#ifdef USE_CUDA + // Run TopK using Raft IVF tiered and flat index and calculate the recall of the Raft IVF + // algorithm with respect to the results returned by the flat index. + static void TopK_TieredRaftIVF(benchmark::State &st, unsigned short index_offset = 0); +#endif // Does nothing but returning the index memory. static void Memory_FLAT(benchmark::State &st, unsigned short index_offset = 0); static void Memory_HNSW(benchmark::State &st, unsigned short index_offset = 0); static void Memory_Tiered(benchmark::State &st, unsigned short index_offset = 0); + static void Memory_TieredRaftIVF(benchmark::State &st, unsigned short index_offset = 0); }; template @@ -82,6 +91,16 @@ void BM_VecSimCommon::Memory_Tiered(benchmark::State &st, st.counters["memory"] = (double)VecSimIndex_Info(INDICES[VecSimAlgo_TIERED + index_offset]).commonInfo.memory; } +template +void BM_VecSimCommon::Memory_TieredRaftIVF(benchmark::State &st, + unsigned short index_offset) { + + for (auto _ : st) { + // Do nothing... + } + st.counters["memory"] = + (double)VecSimIndex_Info(INDICES[VecSimAlgo_RAFT_IVFFLAT + index_offset]).commonInfo.memory; +} // TopK search BM @@ -157,6 +176,78 @@ void BM_VecSimCommon::TopK_Tiered(benchmark::State &st, unsigned s st.counters["num_threads"] = (double)BM_VecSimGeneral::mock_thread_pool.thread_pool_size; } +#ifdef USE_CUDA +template +void BM_VecSimCommon::TopK_TieredRaftIVF(benchmark::State &st, + unsigned short index_offset) { + size_t k = st.range(0); + size_t n_probes = st.range(1); + std::atomic_int correct = 0; + std::atomic_int iter = 0; + auto *tiered_index = + reinterpret_cast *>(INDICES[VecSimAlgo_RAFT_IVFFLAT + index_offset]); + size_t total_iters = 50; + tiered_index->setNProbes(n_probes); + VecSimQueryReply *all_results[total_iters]; + + // Declare 2 lambda to avoid changing AsyncJob type for the JobMock. + auto parallel_knn_search_flat = [](AsyncJob *job) { + auto *search_job = reinterpret_cast(job); + VecSimQueryParams query_params{.batchSize = 1}; + size_t cur_iter = search_job->iter; + auto results = VecSimIndex_TopKQuery(INDICES[VecSimAlgo_RAFT_IVFFLAT], + QUERIES[cur_iter % N_QUERIES].data(), search_job->k, + &query_params, BY_SCORE); + search_job->all_results[cur_iter] = results; + delete job; + }; + + auto parallel_knn_search_pq = [](AsyncJob *job) { + auto *search_job = reinterpret_cast(job); + VecSimQueryParams query_params{.batchSize = 1}; + size_t cur_iter = search_job->iter; + auto results = VecSimIndex_TopKQuery(INDICES[VecSimAlgo_RAFT_IVFPQ], + QUERIES[cur_iter % N_QUERIES].data(), search_job->k, + &query_params, BY_SCORE); + search_job->all_results[cur_iter] = results; + delete job; + }; + + for (auto _ : st) { + if (index_offset == 0) // Flat + { + auto search_job = new (tiered_index->getAllocator()) + tieredIndexMock::SearchJobMock(tiered_index->getAllocator(), parallel_knn_search_flat, + tiered_index, k, 0, iter++, all_results); + tiered_index->submitSingleJob(search_job); + } else // PQ + { + auto search_job = new (tiered_index->getAllocator()) + tieredIndexMock::SearchJobMock(tiered_index->getAllocator(), parallel_knn_search_pq, + tiered_index, k, 0, iter++, all_results); + tiered_index->submitSingleJob(search_job); + } + if (iter == total_iters) { + BM_VecSimGeneral::mock_thread_pool.thread_pool_wait(); + } + } + + // Measure recall + for (iter = 0; iter < total_iters; iter++) { + auto bf_results = + VecSimIndex_TopKQuery(INDICES[VecSimAlgo_BF], + QUERIES[iter % N_QUERIES].data(), k, nullptr, BY_SCORE); + BM_VecSimGeneral::MeasureRecall(all_results[iter], bf_results, correct); + + VecSimQueryReply_Free(bf_results); + VecSimQueryReply_Free(all_results[iter]); + } + + st.counters["Recall"] = (float)correct / (float)(k * iter); + st.counters["num_threads"] = (double)BM_VecSimGeneral::mock_thread_pool.thread_pool_size; +} +#endif + #define REGISTER_TopK_BF(BM_CLASS, BM_FUNC) \ BENCHMARK_REGISTER_F(BM_CLASS, BM_FUNC) \ ->Arg(10) \ @@ -189,3 +280,20 @@ void BM_VecSimCommon::TopK_Tiered(benchmark::State &st, unsigned s ->ArgNames({"ef_runtime", "k"}) \ ->Iterations(50) \ ->Unit(benchmark::kMillisecond) + +#ifdef USE_CUDA +#define REGISTER_TopK_TieredRaftIVF(BM_CLASS, BM_FUNC) \ + BENCHMARK_REGISTER_F(BM_CLASS, BM_FUNC) \ + ->Args({10, 200}) \ + ->Args({10, 500}) \ + ->Args({10, 1500}) \ + ->Args({100, 200}) \ + ->Args({100, 500}) \ + ->Args({100, 1500}) \ + ->Args({200, 200}) \ + ->Args({200, 500}) \ + ->Args({200, 1500}) \ + ->ArgNames({"k", "n_probes"}) \ + ->Iterations(50) \ + ->Unit(benchmark::kMillisecond) +#endif diff --git a/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h b/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h index ae15b3c9e..9a60453c5 100644 --- a/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h +++ b/tests/benchmark/bm_initialization/bm_basics_initialize_fp32.h @@ -20,12 +20,26 @@ BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, Tiered), fp32_ (benchmark::State &st) { Memory_Tiered(st); } BENCHMARK_REGISTER_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, Tiered))->Iterations(1); +#ifdef USE_CUDA +// Memory TieredRaftIVFFlat +BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, TieredRaftIVFFlat), fp32_index_t) +(benchmark::State &st) { Memory_TieredRaftIVF(st); } +BENCHMARK_REGISTER_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, TieredRaftIVFFlat))->Iterations(1); +// Memory TieredRaftIVFPQ +BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, TieredRaftIVFPQ), fp32_index_t) +(benchmark::State &st) { Memory_TieredRaftIVF(st, 1); } +BENCHMARK_REGISTER_F(BM_VecSimCommon, BM_FUNC_NAME(Memory, TieredRaftIVFPQ))->Iterations(1); +#endif + // AddLabel BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimBasics, BM_ADD_LABEL, fp32_index_t) (benchmark::State &st) { AddLabel(st); } REGISTER_AddLabel(BM_ADD_LABEL, VecSimAlgo_BF); REGISTER_AddLabel(BM_ADD_LABEL, VecSimAlgo_HNSWLIB); - +#ifdef USE_CUDA +REGISTER_AddLabel(BM_ADD_LABEL, VecSimAlgo_RAFT_IVFFLAT); +REGISTER_AddLabel(BM_ADD_LABEL, VecSimAlgo_RAFT_IVFPQ); +#endif // DeleteLabel Registration. Definition is placed in the .cpp file. REGISTER_DeleteLabel(BM_FUNC_NAME(DeleteLabel, BF)); REGISTER_DeleteLabel(BM_FUNC_NAME(DeleteLabel, HNSW)); @@ -45,6 +59,18 @@ BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(TopK, Tiered), fp32_in (benchmark::State &st) { TopK_Tiered(st); } REGISTER_TopK_Tiered(BM_VecSimCommon, BM_FUNC_NAME(TopK, Tiered)); +#ifdef USE_CUDA +// TopK Tiered RAFT IVF Flat +BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(TopK, TieredRaftIVFFLAT), fp32_index_t) +(benchmark::State &st) { TopK_TieredRaftIVF(st, 0); } +REGISTER_TopK_TieredRaftIVF(BM_VecSimCommon, BM_FUNC_NAME(TopK, TieredRaftIVFFLAT)); + +// TopK Tiered RAFT IVF PQ +BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimCommon, BM_FUNC_NAME(TopK, TieredRaftIVFPQ), fp32_index_t) +(benchmark::State &st) { TopK_TieredRaftIVF(st, 1); } +REGISTER_TopK_TieredRaftIVF(BM_VecSimCommon, BM_FUNC_NAME(TopK, TieredRaftIVFPQ)); +#endif + // Range BF BENCHMARK_TEMPLATE_DEFINE_F(BM_VecSimBasics, BM_FUNC_NAME(Range, BF), fp32_index_t) (benchmark::State &st) { Range_BF(st); } @@ -72,3 +98,13 @@ BENCHMARK_REGISTER_F(BM_VecSimBasics, BM_DELETE_LABEL_ASYNC) ->Arg(100) ->Arg(BM_VecSimGeneral::block_size) ->ArgName("SwapJobsThreshold"); + +// Tiered RAFT IVF Flat add_async benchmarks +#ifdef USE_CUDA +BENCHMARK_REGISTER_F(BM_VecSimBasics, BM_ADD_LABEL_ASYNC) + ->UNIT_AND_ITERATIONS->Arg(VecSimAlgo_RAFT_IVFFLAT) + ->ArgName("VecSimAlgo_RAFT_IVFFLAT"); +BENCHMARK_REGISTER_F(BM_VecSimBasics, BM_ADD_LABEL_ASYNC) + ->UNIT_AND_ITERATIONS->Arg(VecSimAlgo_RAFT_IVFPQ) + ->ArgName("VecSimAlgo_RAFT_IVFPQ"); +#endif diff --git a/tests/benchmark/bm_vecsim_basics.h b/tests/benchmark/bm_vecsim_basics.h index 0ddfb240d..a57fb3ac1 100644 --- a/tests/benchmark/bm_vecsim_basics.h +++ b/tests/benchmark/bm_vecsim_basics.h @@ -123,7 +123,10 @@ void BM_VecSimBasics::AddLabel_AsyncIngest(benchmark::State &st) { size_t new_label_count = (INDICES[st.range(0)])->indexLabelCount(); // Remove directly inplace from the underline HNSW index. for (size_t label_ = initial_label_count; label_ < new_label_count; label_++) { - VecSimIndex_DeleteVector(INDICES[VecSimAlgo_HNSWLIB], label_); + if (st.range(0) == VecSimAlgo_TIERED) + VecSimIndex_DeleteVector(INDICES[VecSimAlgo_HNSWLIB], label_); + else + VecSimIndex_DeleteVector(INDICES[st.range(0)], label_); } assert(VecSimIndex_IndexSize(INDICES[st.range(0)]) == N_VECTORS); diff --git a/tests/benchmark/bm_vecsim_general.h b/tests/benchmark/bm_vecsim_general.h index 256273faf..21074101f 100644 --- a/tests/benchmark/bm_vecsim_general.h +++ b/tests/benchmark/bm_vecsim_general.h @@ -70,6 +70,42 @@ class BM_VecSimGeneral : public benchmark::Fixture { return params; } + static VecSimParams createDefaultRaftIvfPQParams(size_t dim, uint32_t nLists = 1024, + uint32_t nProbes = 20) { + RaftIvfParams ivfparams = {.dim = dim, + .metric = VecSimMetric_Cosine, + .nLists = nLists, + .kmeans_nIters = 20, + .kmeans_trainsetFraction = 0.5, + .nProbes = nProbes, + .usePQ = true, + .pqBits = 8, + .pqDim = 0, + .codebookKind = RaftIVFPQCodebookKind_PerSubspace, + .lutType = CUDAType_R_32F, + .internalDistanceType = CUDAType_R_32F, + .preferredShmemCarveout = 1.0}; + VecSimParams params{.algo = VecSimAlgo_RAFT_IVFPQ, + .algoParams = {.raftIvfParams = ivfparams}}; + return params; + } + + static VecSimParams createDefaultRaftIvfFlatParams(size_t dim, uint32_t nLists = 1024, + uint32_t nProbes = 20, + bool adaptiveCenters = true) { + RaftIvfParams ivfparams = {.dim = dim, + .metric = VecSimMetric_Cosine, + .nLists = nLists, + .kmeans_nIters = 20, + .kmeans_trainsetFraction = 0.5, + .nProbes = nProbes, + .usePQ = false, + .adaptiveCenters = adaptiveCenters}; + VecSimParams params{.algo = VecSimAlgo_RAFT_IVFFLAT, + .algoParams = {.raftIvfParams = ivfparams}}; + return params; + } + // Gets HNSWParams or BFParams parameters struct, and creates new VecSimIndex. template static inline VecSimIndex *CreateNewIndex(IndexParams &index_params) { diff --git a/tests/benchmark/bm_vecsim_index.h b/tests/benchmark/bm_vecsim_index.h index 8b2406a81..436b65f0c 100644 --- a/tests/benchmark/bm_vecsim_index.h +++ b/tests/benchmark/bm_vecsim_index.h @@ -2,6 +2,10 @@ #include "bm_vecsim_general.h" #include "VecSim/index_factories/tiered_factory.h" +#ifdef USE_CUDA +#include "VecSim/index_factories/raft_ivf_tiered_factory.h" +#include "VecSim/algorithms/raft_ivf/ivf_tiered.h" +#endif template class BM_VecSimIndex : public BM_VecSimGeneral { @@ -111,13 +115,52 @@ void BM_VecSimIndex::Initialize() { // Launch the BG threads loop that takes jobs from the queue and executes them. mock_thread_pool.init_threads(); +#ifdef USE_CUDA + // Create RAFFT IVF Flat tiered index. + // Use one unique thread pool for the tiered index by changing the thread pool context. + VecSimParams params_flat = createDefaultRaftIvfFlatParams(dim, 10000, 100, false); + tiered_params = {.jobQueue = &BM_VecSimGeneral::mock_thread_pool.jobQ, + .jobQueueCtx = mock_thread_pool.ctx, + .submitCb = tieredIndexMock::submit_callback, + .flatBufferLimit = n_vectors, + .primaryIndexParams = ¶ms_flat, + .specificParams = {.tieredRaftIvfParams = {.minVectorsInit = + size_t(n_vectors / params_flat.algoParams.raftIvfParams.nLists)}}}; + + auto *tiered_raft_ivf_flat_index = reinterpret_cast *>( + TieredRaftIvfFactory::NewIndex(&tiered_params)); + + indices.push_back(tiered_raft_ivf_flat_index); + + // Create RAFT IVF PQ tiered index. + // Use one unique thread pool for the tiered index by changing the thread pool context. + VecSimParams params_pq = createDefaultRaftIvfPQParams(dim, 5000, 100); + tiered_params = {.jobQueue = &BM_VecSimGeneral::mock_thread_pool.jobQ, + .jobQueueCtx = mock_thread_pool.ctx, + .submitCb = tieredIndexMock::submit_callback, + .flatBufferLimit = n_vectors, + .primaryIndexParams = ¶ms_pq, + .specificParams = {.tieredRaftIvfParams = {.minVectorsInit = + size_t(n_vectors / params_pq.algoParams.raftIvfParams.nLists)}}}; + + auto *tiered_raft_ivf_pq_index = reinterpret_cast *>( + TieredRaftIvfFactory::NewIndex(&tiered_params)); + + indices.push_back(tiered_raft_ivf_pq_index); +#endif + // Add the same vectors to Flat index. for (size_t i = 0; i < n_vectors; ++i) { const char *blob = GetHNSWDataByInternalId(i); // Fot multi value indices, the internal id is not necessarily equal the label. size_t label = CastToHNSW(indices[VecSimAlgo_HNSWLIB])->getExternalLabel(i); VecSimIndex_AddVector(indices[VecSimAlgo_BF], blob, label); +#ifdef USE_CUDA + VecSimIndex_AddVector(indices[VecSimAlgo_RAFT_IVFFLAT], blob, label); + VecSimIndex_AddVector(indices[VecSimAlgo_RAFT_IVFPQ], blob, label); +#endif } + mock_thread_pool.thread_pool_wait(100); // Load the test query vectors form file. Index file path is relative to repository root dir. loadTestVectors(AttachRootPath(test_queries_file), type); diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index e3cb19bdb..0bd5f363b 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -35,6 +35,11 @@ add_executable(test_allocator test_allocator.cpp test_utils.cpp) add_executable(test_spaces test_spaces.cpp) add_executable(test_common ../utils/mock_thread_pool.cpp test_utils.cpp test_common.cpp) +if(USE_CUDA) + add_executable(test_raftivf ../utils/mock_thread_pool.cpp test_raft_ivf_tiered.cpp test_utils.cpp) + target_link_libraries(test_raftivf PUBLIC gtest_main VectorSimilarity PRIVATE raft::raft) +endif() + target_link_libraries(test_hnsw PUBLIC gtest_main VectorSimilarity) target_link_libraries(test_hnsw_parallel PUBLIC gtest_main VectorSimilarity) target_link_libraries(test_bruteforce PUBLIC gtest_main VectorSimilarity) @@ -50,3 +55,7 @@ gtest_discover_tests(test_bruteforce) gtest_discover_tests(test_allocator) gtest_discover_tests(test_spaces) gtest_discover_tests(test_common) + +if(USE_CUDA) + gtest_discover_tests(test_raftivf) +endif() diff --git a/tests/unit/test_raft_ivf_tiered.cpp b/tests/unit/test_raft_ivf_tiered.cpp new file mode 100644 index 000000000..511da4929 --- /dev/null +++ b/tests/unit/test_raft_ivf_tiered.cpp @@ -0,0 +1,391 @@ +#include "gtest/gtest.h" +#include "VecSim/vec_sim.h" +#include "VecSim/vec_sim_common.h" +#include "VecSim/algorithms/raft_ivf/ivf_tiered.h" +#include "VecSim/index_factories/tiered_factory.h" +#include "test_utils.h" +#include +#include +#include + +#include "mock_thread_pool.h" + +template +class RaftIvfTieredTest : public ::testing::Test { +public: + using data_t = typename index_type_t::data_t; + using dist_t = typename index_type_t::dist_t; + + TieredRaftIvfIndex *createTieredIndex(VecSimParams *params, + tieredIndexMock &mock_thread_pool, + size_t flat_buffer_limit = 0) { + TieredIndexParams params_tiered = { + .jobQueue = &mock_thread_pool.jobQ, + .jobQueueCtx = mock_thread_pool.ctx, + .submitCb = tieredIndexMock::submit_callback, + .flatBufferLimit = flat_buffer_limit, + .primaryIndexParams = params, + }; + auto *tiered_index = TieredFactory::NewIndex(¶ms_tiered); + // Set the created tiered index in the index external context (it will take ownership over + // the index, and we'll need to release the ctx at the end of the test. + mock_thread_pool.ctx->index_strong_ref.reset(tiered_index); + + return reinterpret_cast *>(tiered_index); + } +}; + +VecSimParams createDefaultPQParams(size_t dim, uint32_t nLists = 3, uint32_t nProbes = 3) { + RaftIvfParams ivfparams = {.dim = dim, + .metric = VecSimMetric_L2, + .nLists = nLists, + .kmeans_nIters = 20, + .kmeans_trainsetFraction = 0.5, + .nProbes = nProbes, + .usePQ = true, + .pqBits = 8, + .pqDim = 0, + .codebookKind = RaftIVFPQCodebookKind_PerSubspace, + .lutType = CUDAType_R_32F, + .internalDistanceType = CUDAType_R_32F, + .preferredShmemCarveout = 1.0}; + VecSimParams params{.algo = VecSimAlgo_RAFT_IVFPQ, .algoParams = {.raftIvfParams = ivfparams}}; + return params; +} + +VecSimParams createDefaultFlatParams(size_t dim, uint32_t nLists = 3, uint32_t nProbes = 3) { + RaftIvfParams ivfparams = {.dim = dim, + .metric = VecSimMetric_L2, + .nLists = nLists, + .kmeans_nIters = 20, + .kmeans_trainsetFraction = 0.5, + .nProbes = nProbes, + .usePQ = false}; + VecSimParams params{.algo = VecSimAlgo_RAFT_IVFFLAT, + .algoParams = {.raftIvfParams = ivfparams}}; + return params; +} + +using DataTypeSetFloat = ::testing::Types>; + +TYPED_TEST_SUITE(RaftIvfTieredTest, DataTypeSetFloat); + +TYPED_TEST(RaftIvfTieredTest, end_to_end) { + size_t dim = 4; + size_t flat_buffer_limit = 3; + size_t nLists = 2; + + VecSimParams params = createDefaultFlatParams(dim, nLists, nLists); + auto mock_thread_pool = tieredIndexMock(); + auto *index = this->createTieredIndex(¶ms, mock_thread_pool, flat_buffer_limit); + mock_thread_pool.init_threads(); + + VecSimQueryParams queryParams = {.batchSize = 1}; + + ASSERT_EQ(VecSimIndex_IndexSize(index), 0); + + TEST_DATA_T a[dim], b[dim], c[dim], d[dim], e[dim], zero[dim]; + std::vector a_vec(dim, (TEST_DATA_T)1); + std::vector b_vec(dim, (TEST_DATA_T)2); + std::vector c_vec(dim, (TEST_DATA_T)4); + std::vector d_vec(dim, (TEST_DATA_T)5); + std::vector zero_vec(dim, (TEST_DATA_T)0); + + auto inserted_vectors = std::vector>{a_vec, b_vec, c_vec, d_vec}; + + // Search for vectors when the index is empty. + runTopKSearchTest(index, a_vec.data(), 1, nullptr); + + // Add vectors. + VecSimIndex_AddVector(index, a_vec.data(), 0); + ASSERT_EQ(VecSimIndex_IndexSize(index), 1); + VecSimIndex_AddVector(index, b_vec.data(), 1); + VecSimIndex_AddVector(index, c_vec.data(), 2); + VecSimIndex_AddVector(index, d_vec.data(), 3); + ASSERT_EQ(VecSimIndex_IndexSize(index), 4); + + mock_thread_pool.thread_pool_join(); + EXPECT_EQ(mock_thread_pool.jobQ.size(), 0); + // Callbacks for verifying results. + auto ver_res_0 = [&](size_t id, double score, size_t index) { + ASSERT_EQ(id, index); + ASSERT_DOUBLE_EQ(score, dim * inserted_vectors[id][0] * inserted_vectors[id][0]); + }; + size_t result_c[] = {2, 3, 1, 0}; // Order of results for query on c. + auto ver_res_c = [&](size_t id, double score, size_t index) { + ASSERT_EQ(id, result_c[index]); + double dist = inserted_vectors[id][0] - c_vec[0]; + ASSERT_DOUBLE_EQ(score, dim * dist * dist); + }; + + auto k = 4; + runTopKSearchTest(index, zero_vec.data(), k, ver_res_0); + runTopKSearchTest(index, c_vec.data(), k, ver_res_c); +} + +TYPED_TEST(RaftIvfTieredTest, transferJob) { + // Create RAFT Tiered index instance with a mock queue. + + size_t dim = 4; + size_t flat_buffer_limit = 3; + size_t nLists = 1; + + VecSimParams params = createDefaultFlatParams(dim, nLists, nLists); + auto mock_thread_pool = tieredIndexMock(); + auto *tiered_index = this->createTieredIndex(¶ms, mock_thread_pool, flat_buffer_limit); + auto allocator = tiered_index->getAllocator(); + + VecSimQueryParams queryParams = {.batchSize = 1}; + + // Create a vector and add it to the tiered index. + labelType vec_label = 1; + TEST_DATA_T vector[dim]; + GenerateVector(vector, dim, vec_label); + VecSimIndex_AddVector(tiered_index, vector, vec_label); + ASSERT_EQ(tiered_index->indexSize(), 1); + ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 1); + ASSERT_EQ(tiered_index->frontendIndex->getDistanceFrom_Unsafe(vec_label, vector), 0); + + // Execute the insert job manually (in a synchronous manner). + ASSERT_EQ(mock_thread_pool.jobQ.size(), 1); + auto *insertion_job = reinterpret_cast(mock_thread_pool.jobQ.front().job); + ASSERT_EQ(insertion_job->jobType, RAFT_TRANSFER_JOB); + + mock_thread_pool.thread_iteration(); + ASSERT_EQ(tiered_index->indexSize(), 1); + ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 0); + ASSERT_EQ(tiered_index->backendIndex->indexSize(), 1); + // RAFT IVF index should have allocated a single block, while flat index should remove the + // block. + ASSERT_EQ(tiered_index->frontendIndex->indexCapacity(), 0); + // After the execution, the job should be removed from the job queue. + ASSERT_EQ(mock_thread_pool.jobQ.size(), 0); +} + +TYPED_TEST(RaftIvfTieredTest, transferJobAsync) { + size_t dim = 32; + size_t n = 500; + size_t nLists = 120; + size_t flat_buffer_limit = 160; + + size_t k = 1; + + // Create RaftIvfTiered index instance with a mock queue. + VecSimParams params = createDefaultFlatParams(dim, nLists, 20); + auto mock_thread_pool = tieredIndexMock(); + auto *tiered_index = this->createTieredIndex(¶ms, mock_thread_pool, flat_buffer_limit); + + // Launch the BG threads loop that takes jobs from the queue and executes them. + mock_thread_pool.init_threads(); + // Insert vectors + for (size_t i = 0; i < n; i++) { + GenerateAndAddVector(tiered_index, dim, i, i); + } + + mock_thread_pool.thread_pool_join(); + // Verify that the vectors were inserted to RaftIvf as expected, that the jobqueue is empty, + ASSERT_EQ(tiered_index->indexSize(), n); + ASSERT_EQ(tiered_index->backendIndex->indexSize(), n); + ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 0); + ASSERT_EQ(mock_thread_pool.jobQ.size(), 0); + // Verify that the vectors were inserted to RaftIvf as expected + for (size_t i = 0; i < size_t{n / 10}; i++) { + TEST_DATA_T expected_vector[dim]; + GenerateVector(expected_vector, dim, i); + VecSimQueryReply *res = VecSimIndex_TopKQuery(tiered_index->backendIndex, expected_vector, + k, nullptr, BY_SCORE); + ASSERT_EQ(VecSimQueryReply_GetCode(res), VecSim_QueryReply_OK); + ASSERT_EQ(VecSimQueryReply_Len(res), k); + ASSERT_EQ(res->results[0].id, i); + ASSERT_EQ(res->results[0].score, 0); + VecSimQueryReply_Free(res); + } +} + +TYPED_TEST(RaftIvfTieredTest, transferJob_inplace) { + size_t dim = 32; + size_t n = 200; + size_t nLists = 120; + size_t flat_buffer_limit = 160; + + size_t k = 1; + + // Create RaftIvfTiered index instance with a mock queue. + VecSimParams params = createDefaultFlatParams(dim, nLists, 20); + auto mock_thread_pool = tieredIndexMock(); + auto *tiered_index = this->createTieredIndex(¶ms, mock_thread_pool, flat_buffer_limit); + + // In the absence of BG threads to takes jobs from the queue, the tiered index should + // transfer in place when flat_buffer is over the limit. + for (size_t i = 0; i < n; i++) { + GenerateAndAddVector(tiered_index, dim, i, i); + } + + ASSERT_EQ(tiered_index->indexSize(), n); + ASSERT_EQ(tiered_index->backendIndex->indexSize(), flat_buffer_limit); + ASSERT_EQ(tiered_index->frontendIndex->indexSize(), n - flat_buffer_limit); + + // Run another batch of insertion. The tiered index should transfer inplace again. + for (size_t i = n; i < n * 2; i++) { + GenerateAndAddVector(tiered_index, dim, i, i); + } + ASSERT_EQ(tiered_index->indexSize(), 2 * n); + ASSERT_EQ(tiered_index->backendIndex->indexSize(), flat_buffer_limit * 2); + ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 2 * (n - flat_buffer_limit)); +} + +TYPED_TEST(RaftIvfTieredTest, deleteVector_backend) { + size_t dim = 32; + size_t n = 500; + size_t nLists = 120; + size_t nDelete = 10; + size_t flat_buffer_limit = 1000; + + size_t k = 1; + + // Create RaftIvfTiered index instance with a mock queue. + VecSimParams params = createDefaultFlatParams(dim, nLists, nLists); + auto mock_thread_pool = tieredIndexMock(); + auto *tiered_index = this->createTieredIndex(¶ms, mock_thread_pool, flat_buffer_limit); + + labelType vec_label = 0; + // Delete from an empty index. + ASSERT_EQ(VecSimIndex_DeleteVector(tiered_index, vec_label), 0); + + // Launch the BG threads loop that takes jobs from the queue and executes them. + mock_thread_pool.init_threads(); + // Insert vectors + for (size_t i = 0; i < n; i++) { + GenerateAndAddVector(tiered_index, dim, i, i); + } + // Use just one thread to transfer all the vectors + mock_thread_pool.thread_pool_wait(100); + + // Check that the backend index has the first 12 vectors + ASSERT_EQ(tiered_index->indexSize(), n); + ASSERT_EQ(tiered_index->backendIndex->indexSize(), n); + ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 0); + for (size_t i = 0; i < nDelete + 2; i++) { + TEST_DATA_T expected_vector[dim]; + GenerateVector(expected_vector, dim, i); + VecSimQueryReply *res = VecSimIndex_TopKQuery(tiered_index->backendIndex, expected_vector, + k, nullptr, BY_SCORE); + ASSERT_EQ(VecSimQueryReply_GetCode(res), VecSim_QueryReply_OK); + ASSERT_EQ(VecSimQueryReply_Len(res), k); + ASSERT_EQ(res->results[0].id, i); + ASSERT_EQ(res->results[0].score, 0); + VecSimQueryReply_Free(res); + } + + // Delete 10 first vectors + for (size_t i = 0; i < nDelete; i++) { + VecSimIndex_DeleteVector(tiered_index, i); + } + + ASSERT_EQ(tiered_index->indexSize(), n - nDelete); + ASSERT_EQ(tiered_index->backendIndex->indexSize(), n - nDelete); + ASSERT_EQ(tiered_index->frontendIndex->indexSize(), 0); +} + +TYPED_TEST(RaftIvfTieredTest, searchMetricCosine) { + size_t dim = 32; + size_t n = 25; + size_t nLists = 5; + size_t flat_buffer_limit = 100; + + size_t k = 10; + + // Create RaftIvfTiered index instance with a mock queue. + VecSimParams params = createDefaultFlatParams(dim, nLists, nLists); + + // Set the metric to cosine. + params.algoParams.raftIvfParams.metric = VecSimMetric_Cosine; + + auto mock_thread_pool = tieredIndexMock(); + auto *tiered_index = this->createTieredIndex(¶ms, mock_thread_pool, flat_buffer_limit); + + // Launch the BG threads loop that takes jobs from the queue and executes them. + mock_thread_pool.init_threads(); + std::vector> inserted_vectors; + + for (size_t i = 0; i < n; i++) { + inserted_vectors.push_back(std::vector(dim)); + // Generate vectors + for (size_t j = 0; j < dim; j++) { + inserted_vectors.back()[j] = (TEST_DATA_T)i + j; + } + // Insert vectors + VecSimIndex_AddVector(tiered_index, inserted_vectors.back().data(), i); + } + mock_thread_pool.thread_pool_wait(100); + + // The query is a vector with half of the values equal to 8.1 and the other half equal to 1.1. + TEST_DATA_T query[dim]; + TEST_DATA_T query_norm[dim]; + GenerateVector(query, dim / 2, 8.1f); + GenerateVector(query + dim / 2, dim / 2, 1.1f); + memcpy(query_norm, query, dim * sizeof(TEST_DATA_T)); + VecSim_Normalize(query_norm, dim, VecSimType_FLOAT32); + + auto verify_cb = [&](size_t id, double score, size_t index) { + TEST_DATA_T neighbor_norm[dim]; + memcpy(neighbor_norm, inserted_vectors[id].data(), dim * sizeof(TEST_DATA_T)); + VecSim_Normalize(neighbor_norm, dim, VecSimType_FLOAT32); + + // Use distance function of the bruteforce index to verify the score. + double dist = tiered_index->frontendIndex->getDistFunc()( + query_norm, + neighbor_norm, + dim); + ASSERT_NEAR(score, dist, 1e-5); + }; + + runTopKSearchTest(tiered_index, query, k, verify_cb); +} + +TYPED_TEST(RaftIvfTieredTest, searchMetricIP) { + size_t dim = 4; + size_t n = 25; + size_t nLists = 5; + size_t flat_buffer_limit = 100; + + size_t k = 10; + + // Create RaftIvfTiered index instance with a mock queue. + VecSimParams params = createDefaultFlatParams(dim, nLists, nLists); + + // Set the metric to Inner Product. + params.algoParams.raftIvfParams.metric = VecSimMetric_IP; + + auto mock_thread_pool = tieredIndexMock(); + auto *tiered_index = this->createTieredIndex(¶ms, mock_thread_pool, flat_buffer_limit); + + // Launch the BG threads loop that takes jobs from the queue and executes them. + mock_thread_pool.init_threads(); + std::vector> inserted_vectors; + + for (size_t i = 0; i < n; i++) { + inserted_vectors.push_back(std::vector(dim)); + // Generate vectors + for (size_t j = 0; j < dim; j++) { + inserted_vectors.back()[j] = (TEST_DATA_T)i + j; + } + // Insert vectors + VecSimIndex_AddVector(tiered_index, inserted_vectors.back().data(), i); + } + mock_thread_pool.thread_pool_wait(100); + + // The query is a vector with half of the values equal to 1.1 and the other half equal to 0.1. + TEST_DATA_T query[dim] = {1.1f, 1.1f, 0.1f, 0.1f}; + + auto verify_cb = [&](size_t id, double score, size_t index) { + // Use distance function of the bruteforce index to verify the score. + double dist = tiered_index->frontendIndex->getDistFunc()( + query, + inserted_vectors[id].data(), + dim); + ASSERT_NEAR(score, dist, 1e-5); + }; + + runTopKSearchTest(tiered_index, query, k, verify_cb); +}