diff --git a/AnyBuildLogs/latest.txt b/AnyBuildLogs/latest.txt index 38b4a947f..7d2f2ab9a 100644 --- a/AnyBuildLogs/latest.txt +++ b/AnyBuildLogs/latest.txt @@ -1 +1 @@ -20231019-111207-d314f8bf \ No newline at end of file +20260110-171520-4f890e5e \ No newline at end of file diff --git a/include/abstract_index.h b/include/abstract_index.h index 175509552..525759e34 100644 --- a/include/abstract_index.h +++ b/include/abstract_index.h @@ -71,11 +71,13 @@ class AbstractIndex // IDtype is either uint32_t or uint64_t template std::pair search(const data_type *query, const size_t K, const uint32_t L, IDType *indices, - float *distances = nullptr); + float *distances = nullptr, + std::function rerank_fn = nullptr); template std::pair diverse_search(const data_type* query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, IDType* indices, - float* distances = nullptr); + float* distances = nullptr, + std::function rerank_fn = nullptr); // Filter support search // IndexType is either uint32_t or uint64_t @@ -83,7 +85,8 @@ class AbstractIndex std::pair search_with_filters(const DataType &query, const std::vector &raw_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, IndexType *indices, - float *distances); + float *distances, + std::function rerank_fn = nullptr); // insert points with labels, labels should be present for filtered index template @@ -122,12 +125,15 @@ class AbstractIndex private: virtual void _build(const DataType &data, const size_t num_points_to_load, TagVector &tags) = 0; virtual std::pair _search(const DataType &query, const size_t K, const uint32_t L, - std::any &indices, float *distances = nullptr) = 0; + std::any &indices, float *distances = nullptr, + std::function rerank_fn = nullptr) = 0; virtual std::pair _diverse_search(const DataType& query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, - std::any& indices, float* distances = nullptr) = 0; + std::any& indices, float* distances = nullptr, + std::function rerank_fn = nullptr) = 0; virtual std::pair _search_with_filters(const DataType &query, const std::vector &filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, std::any &indices, - float *distances) = 0; + float *distances, + std::function rerank_fn) = 0; virtual int _insert_point(const DataType &data_point, const TagType tag, const std::vector &labels) = 0; virtual int _insert_point(const DataType &data_point, const TagType tag) = 0; virtual int _lazy_delete(const TagType &tag) = 0; diff --git a/include/defaults.h b/include/defaults.h index 240a57b62..eb646a494 100644 --- a/include/defaults.h +++ b/include/defaults.h @@ -34,5 +34,9 @@ const uint32_t SEARCH_LIST_SIZE = 100; const bool DIVERSE_INDEX = false; const std::string EMPTY_STRING = ""; const bool NUM_DIVERSE_BUILD = 1; + +const bool REORDER_INDEX = false; +const uint32_t REORDER_DIM = 0; + } // namespace defaults } // namespace diskann diff --git a/include/in_mem_data_store.h b/include/in_mem_data_store.h index 489f1a729..c32a39f0a 100644 --- a/include/in_mem_data_store.h +++ b/include/in_mem_data_store.h @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. +#pragma once #include #include @@ -72,7 +73,7 @@ template class InMemDataStore : public AbstractDataStore +class InMemReorderDataStore : public InMemDataStore +{ +public: + InMemReorderDataStore(location_t capacity, size_t search_dim, size_t reorder_dim, + std::unique_ptr> search_distance_fn); + + virtual ~InMemReorderDataStore() = default; + + virtual size_t get_dims() const override; + + virtual size_t get_aligned_dim() const override; + + virtual void get_vector(const location_t i, data_t* target) const override; + + virtual void set_vector(const location_t i, const data_t* const vector) override + { + throw std::runtime_error("set_vector not supported in InMemReorderDataStore"); + } + + virtual void prefetch_vector(const location_t loc) const override; + + virtual float get_distance(const data_t* preprocessed_query, const location_t loc) const override; + + virtual void get_distance(const data_t* preprocessed_query, const location_t* locations, + const uint32_t location_count, float* distances, + AbstractScratch* scratch) const override; + + virtual float get_distance(const location_t loc1, const location_t loc2) const override; + + virtual void get_distance(const data_t* preprocessed_query, const std::vector& ids, + std::vector& distances, AbstractScratch* scratch_space) const override; + + size_t get_reorder_aligned_dim() const; + + void get_reorder_vector(const location_t i, data_t *target) const; + + const data_t* get_reorder_vector(const location_t i) const; + +protected: + virtual location_t expand(const location_t new_size) override + { + throw std::runtime_error("expand not supported in InMemReorderDataStore"); + } + + virtual location_t shrink(const location_t new_size) override + { + throw std::runtime_error("shrink not supported in InMemReorderDataStore"); + } + +private: + size_t _search_dim = 0; + size_t _search_aligned_dim = 0; + +}; +} \ No newline at end of file diff --git a/include/index.h b/include/index.h index b4eb6f681..59577e53c 100644 --- a/include/index.h +++ b/include/index.h @@ -141,11 +141,13 @@ template clas // can customize L on a per-query basis without tampering with "Parameters" template DISKANN_DLLEXPORT std::pair search(const T *query, const size_t K, const uint32_t L, - IDType *indices, float *distances = nullptr, const uint32_t maxLperSeller = 0); + IDType *indices, float *distances = nullptr, const uint32_t maxLperSeller = 0, + std::function rerank_fn = nullptr); template std::pair diverse_search(const T* query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, IDType* indices, - float* distances = nullptr); + float* distances = nullptr, + std::function rerank_fn = nullptr); // Initialize space for res_vectors before calling. DISKANN_DLLEXPORT size_t search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags, @@ -153,13 +155,15 @@ template clas const std::vector& filter_labels); virtual std::pair _diverse_search(const DataType& query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, - std::any& indices, float* distances = nullptr) override; + std::any& indices, float* distances = nullptr, + std::function rerank_fn = nullptr) override; // Filter support search template DISKANN_DLLEXPORT std::pair search_with_filters(const T *query, const std::vector &filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, - IndexType *indices, float *distances); + IndexType *indices, float *distances, + std::function rerank_fn = nullptr); // Will fail if tag already in the index or if tag=0. DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag); @@ -220,11 +224,13 @@ template clas virtual void _build(const DataType &data, const size_t num_points_to_load, TagVector &tags) override; virtual std::pair _search(const DataType &query, const size_t K, const uint32_t L, - std::any &indices, float *distances = nullptr) override; + std::any &indices, float *distances = nullptr, + std::function rerank_fn = nullptr) override; virtual std::pair _search_with_filters(const DataType &query, const std::vector &filter_labels_raw, const size_t K, const uint32_t L, const uint32_t maxLperSeller, std::any &indices, - float *distances) override; + float *distances, + std::function rerank_fn = nullptr) override; virtual int _insert_point(const DataType &data_point, const TagType tag) override; virtual int _insert_point(const DataType &data_point, const TagType tag, const std::vector &labels) override; @@ -328,6 +334,10 @@ template clas // Acquire exclusive _update_lock and _tag_lock before calling. void resize(size_t new_max_points); + template + void post_process_search_results(InMemQueryScratch *scratch, const size_t K, IDType *indices, float *distances, + std::function rerank_fn); + // Acquire unique lock on _update_lock, _consolidate_lock, _tag_lock // and _delete_lock before calling these functions. // Renumber nodes, update tag and location maps and compact the @@ -477,6 +487,9 @@ template clas bool _use_integer_labels = false; integer_label_vector _label_vector; + + bool _reorder_index = false; + uint32_t _search_dim = 0; // only used when _reorder_index = true TableStats _table_stats; diff --git a/include/index_config.h b/include/index_config.h index 2ff46260d..28cacf95d 100644 --- a/include/index_config.h +++ b/include/index_config.h @@ -5,7 +5,8 @@ namespace diskann { enum class DataStoreStrategy { - MEMORY + MEMORY, + REORDER_MEMORY }; enum class GraphStoreStrategy @@ -39,6 +40,9 @@ struct IndexConfig std::string tag_type; std::string data_type; + bool reorder_index; + uint32_t search_dim; + // Params for building index std::shared_ptr index_write_params; // Params for searching index @@ -49,12 +53,13 @@ struct IndexConfig size_t max_points, size_t num_pq_chunks, size_t num_frozen_points, bool dynamic_index, bool enable_tags, bool pq_dist_build, bool concurrent_consolidate, bool use_opq, bool filtered_index, std::string &data_type, const std::string &tag_type, const std::string &label_type, + bool reorder_index, uint32_t search_dim, std::shared_ptr index_write_params, std::shared_ptr index_search_params) : data_strategy(data_strategy), graph_strategy(graph_strategy), metric(metric), dimension(dimension), max_points(max_points), dynamic_index(dynamic_index), enable_tags(enable_tags), pq_dist_build(pq_dist_build), concurrent_consolidate(concurrent_consolidate), use_opq(use_opq), filtered_index(filtered_index), - num_pq_chunks(num_pq_chunks), num_frozen_pts(num_frozen_points), label_type(label_type), tag_type(tag_type), + num_pq_chunks(num_pq_chunks), num_frozen_pts(num_frozen_points), reorder_index(reorder_index), search_dim(search_dim), label_type(label_type), tag_type(tag_type), data_type(data_type), index_write_params(index_write_params), index_search_params(index_search_params) { } @@ -163,6 +168,18 @@ class IndexConfigBuilder return *this; } + IndexConfigBuilder &with_reorder_index(bool reorder_index) + { + this->_reorder_index = reorder_index; + return *this; + } + + IndexConfigBuilder &with_search_dim(uint32_t search_dim) + { + this->_search_dim = search_dim; + return *this; + } + IndexConfigBuilder &with_index_write_params(IndexWriteParameters &index_write_params) { this->_index_write_params = std::make_shared(index_write_params); @@ -222,8 +239,8 @@ class IndexConfigBuilder return IndexConfig(_data_strategy, _graph_strategy, _metric, _dimension, _max_points, _num_pq_chunks, _num_frozen_pts, _dynamic_index, _enable_tags, _pq_dist_build, _concurrent_consolidate, - _use_opq, _filtered_index, _data_type, _tag_type, _label_type, _index_write_params, - _index_search_params); + _use_opq, _filtered_index, _data_type, _tag_type, _label_type, _reorder_index, _search_dim, + _index_write_params, _index_search_params); } IndexConfigBuilder(const IndexConfigBuilder &) = delete; @@ -251,6 +268,9 @@ class IndexConfigBuilder std::string _tag_type{"uint32"}; std::string _data_type; + bool _reorder_index = false; + uint32_t _search_dim = 0; + std::shared_ptr _index_write_params; std::shared_ptr _index_search_params; }; diff --git a/include/index_factory.h b/include/index_factory.h index 80bc40dba..140643b53 100644 --- a/include/index_factory.h +++ b/include/index_factory.h @@ -15,9 +15,11 @@ class IndexFactory const GraphStoreStrategy stratagy, const size_t size, const size_t reserve_graph_degree); template - DISKANN_DLLEXPORT static std::shared_ptr> construct_datastore(DataStoreStrategy stratagy, - size_t num_points, - size_t dimension, Metric m); + DISKANN_DLLEXPORT static std::shared_ptr> construct_datastore(const IndexConfig& index_config); + + template + DISKANN_DLLEXPORT static std::shared_ptr> construct_mem_datastore(size_t total_internal_points, size_t dimension, + Metric metric); // For now PQDataStore incorporates within itself all variants of quantization that we support. In the // future it may be necessary to introduce an AbstractPQDataStore class to spearate various quantization // flavours. diff --git a/include/parameters.h b/include/parameters.h index 01a8b834c..2b08ade86 100644 --- a/include/parameters.h +++ b/include/parameters.h @@ -20,7 +20,6 @@ enum class LabelFormatType :uint8_t }; class IndexWriteParameters - { public: const uint32_t search_list_size; // L diff --git a/include/scratch.h b/include/scratch.h index 7fe3e0da8..ce30f1249 100644 --- a/include/scratch.h +++ b/include/scratch.h @@ -106,6 +106,10 @@ template class InMemQueryScratch : public AbstractScratch { return _query_label_bitmask; } + inline std::vector& reranked_results() + { + return _reranked_results; + } private: uint32_t _L; @@ -149,6 +153,8 @@ template class InMemQueryScratch : public AbstractScratch std::vector _occlude_list_output; // bitmask buffer in searching time std::vector _query_label_bitmask; + // Buffer for reranking results to avoid repeated allocations + std::vector _reranked_results; }; // diff --git a/src/abstract_index.cpp b/src/abstract_index.cpp index 2084bac6d..79b1dd7de 100644 --- a/src/abstract_index.cpp +++ b/src/abstract_index.cpp @@ -15,20 +15,21 @@ void AbstractIndex::build(const data_type *data, const size_t num_points_to_load template std::pair AbstractIndex::search(const data_type *query, const size_t K, const uint32_t L, - IDType *indices, float *distances) + IDType *indices, float *distances, + std::function rerank_fn) { auto any_indices = std::any(indices); auto any_query = std::any(query); - return _search(any_query, K, L, any_indices, distances); + return _search(any_query, K, L, any_indices, distances, std::move(rerank_fn)); } template std::pair AbstractIndex::diverse_search(const data_type* query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, - IDType* indices, float* distances) + IDType* indices, float* distances, std::function rerank_fn) { auto any_indices = std::any(indices); auto any_query = std::any(query); - return _diverse_search(any_query, K, L, maxLperSeller, any_indices, distances); + return _diverse_search(any_query, K, L, maxLperSeller, any_indices, distances, std::move(rerank_fn)); } template @@ -45,10 +46,10 @@ size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K, template std::pair AbstractIndex::search_with_filters(const DataType &query, const std::vector &raw_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, IndexType *indices, - float *distances) + float *distances, std::function rerank_fn) { auto any_indices = std::any(indices); - return _search_with_filters(query, raw_labels, K, L, maxLperSeller, any_indices, distances); + return _search_with_filters(query, raw_labels, K, L, maxLperSeller, any_indices, distances, std::move(rerank_fn)); } template @@ -159,40 +160,51 @@ template DISKANN_DLLEXPORT void AbstractIndex::build(const const std::vector& tags); template DISKANN_DLLEXPORT std::pair AbstractIndex::search( - const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair AbstractIndex::search( - const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair AbstractIndex::search( - const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair AbstractIndex::search( - const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair AbstractIndex::search( - const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair AbstractIndex::search( - const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair AbstractIndex::search_with_filters( const DataType &query, const std::vector &raw_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair AbstractIndex::search_with_filters( const DataType &query, const std::vector& raw_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, - float *distances); - + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( - const float* query, const size_t K, const uint32_t L, const uint32_t maxL, uint32_t* indices, float* distances); + const float* query, const size_t K, const uint32_t L, const uint32_t maxL, uint32_t* indices, float* distances, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( - const uint8_t* query, const size_t K, const uint32_t L, const uint32_t maxL, uint32_t* indices, float* distances); + const uint8_t* query, const size_t K, const uint32_t L, const uint32_t maxL, uint32_t* indices, float* distances, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( - const int8_t* query, const size_t K, const uint32_t L, const uint32_t maxL, uint32_t* indices, float* distances); + const int8_t* query, const size_t K, const uint32_t L, const uint32_t maxL, uint32_t* indices, float* distances, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( - const float* query, const size_t K, const uint32_t L, const uint32_t maxL, uint64_t* indices, float* distances); + const float* query, const size_t K, const uint32_t L, const uint32_t maxL, uint64_t* indices, float* distances, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( - const uint8_t* query, const size_t K, const uint32_t L, const uint32_t maxL, uint64_t* indices, float* distances); + const uint8_t* query, const size_t K, const uint32_t L, const uint32_t maxL, uint64_t* indices, float* distances, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( - const int8_t* query, const size_t K, const uint32_t L, const uint32_t maxL, uint64_t* indices, float* distances); + const int8_t* query, const size_t K, const uint32_t L, const uint32_t maxL, uint64_t* indices, float* distances, + std::function rerank_fn); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags( const float *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, diff --git a/src/in_mem_data_store.cpp b/src/in_mem_data_store.cpp index 0d62b1a45..7bce11fb4 100644 --- a/src/in_mem_data_store.cpp +++ b/src/in_mem_data_store.cpp @@ -112,7 +112,7 @@ template location_t InMemDataStore::load_impl(const st template size_t InMemDataStore::save(const std::string &filename, const location_t num_points) { - return save_data_in_base_dimensions(filename, _data, num_points, this->get_dims(), this->get_aligned_dim(), 0U); + return save_data_in_base_dimensions(filename, _data, num_points, this->_dim, this->_aligned_dim, 0U); } template void InMemDataStore::populate_data(const data_t *vectors, const location_t num_pts) @@ -143,11 +143,11 @@ template void InMemDataStore::populate_data(const std: throw diskann::ANNException(ss.str(), -1); } - if ((location_t)ndim != this->get_dims()) + if ((location_t)ndim != this->_dim) { std::stringstream ss; ss << "Number of dimensions of a point in the file: " << filename - << " is not equal to dimensions of data store: " << this->capacity() << "." << std::endl; + << " is not equal to dimensions of data store: " << this->_dim << "." << std::endl; throw diskann::ANNException(ss.str(), -1); } @@ -160,7 +160,7 @@ template void InMemDataStore::populate_data(const std: template void InMemDataStore::extract_data_to_bin(const std::string &filename, const location_t num_points) { - save_data_in_base_dimensions(filename, _data, num_points, this->get_dims(), this->get_aligned_dim(), 0U); + save_data_in_base_dimensions(filename, _data, num_points, this->_dim, this->_aligned_dim, 0U); } template void InMemDataStore::get_vector(const location_t i, data_t *dest) const diff --git a/src/in_mem_reorder_data_store.cpp b/src/in_mem_reorder_data_store.cpp new file mode 100644 index 000000000..96836ae87 --- /dev/null +++ b/src/in_mem_reorder_data_store.cpp @@ -0,0 +1,93 @@ +#include "in_mem_reorder_data_store.h" + +namespace diskann +{ + +template +InMemReorderDataStore::InMemReorderDataStore(location_t capacity, size_t search_dim, size_t reorder_dim, + std::unique_ptr> search_distance_fn) + : InMemDataStore(capacity, reorder_dim, std::move(search_distance_fn)) +{ + _search_dim = search_dim; + _search_aligned_dim = ROUND_UP(search_dim, this->_distance_fn->get_required_alignment()); +} + +template size_t InMemReorderDataStore::get_dims() const +{ + return _search_dim; +} + +template size_t InMemReorderDataStore::get_aligned_dim() const +{ + return _search_aligned_dim; +} + +template +void InMemReorderDataStore::get_vector(const location_t i, data_t* target) const +{ + memcpy(target, this->_data + i * this->_aligned_dim, _search_dim * sizeof(data_t)); +} + +template void InMemReorderDataStore::prefetch_vector(const location_t loc) const +{ + diskann::prefetch_vector((const char*)this->_data + this->_aligned_dim * (size_t)loc * sizeof(data_t), + sizeof(data_t) * _search_aligned_dim); +} + +template float InMemReorderDataStore::get_distance(const data_t* preprocessed_query, const location_t loc) const +{ + return this->_distance_fn->compare(preprocessed_query, this->_data + this->_aligned_dim * loc, (uint32_t)_search_aligned_dim); +} + +template +void InMemReorderDataStore::get_distance(const data_t* preprocessed_query, const location_t* locations, + const uint32_t location_count, float* distances, + AbstractScratch* scratch) const +{ + for (location_t i = 0; i < location_count; i++) + { + distances[i] = this->_distance_fn->compare(preprocessed_query, this->_data + locations[i] * this->_aligned_dim, (uint32_t)this->_search_aligned_dim); + } +} + +template +float InMemReorderDataStore::get_distance(const location_t loc1, const location_t loc2) const +{ + return this->_distance_fn->compare(this->_data + loc1 * this->_aligned_dim, this->_data + loc2 * this->_aligned_dim, + (uint32_t)this->_search_aligned_dim); +} + +template +void InMemReorderDataStore::get_distance(const data_t* preprocessed_query, const std::vector& ids, + std::vector& distances, AbstractScratch* scratch_space) const +{ + for (int i = 0; i < ids.size(); i++) + { + distances[i] = + this->_distance_fn->compare(preprocessed_query, this->_data + ids[i] * this->_aligned_dim, (uint32_t)this->_search_aligned_dim); + } +} + +template +size_t InMemReorderDataStore::get_reorder_aligned_dim() const +{ + return this->_aligned_dim; +} + +template +void InMemReorderDataStore::get_reorder_vector(const location_t i, data_t* target) const +{ + memcpy(target, this->_data + i * this->_aligned_dim, this->_aligned_dim * sizeof(data_t)); +} + +template +const data_t* InMemReorderDataStore::get_reorder_vector(const location_t i) const +{ + return this->_data + i * this->_aligned_dim; +} + +template DISKANN_DLLEXPORT class InMemReorderDataStore; +template DISKANN_DLLEXPORT class InMemReorderDataStore; +template DISKANN_DLLEXPORT class InMemReorderDataStore ; + +} \ No newline at end of file diff --git a/src/index.cpp b/src/index.cpp index f18b9966f..8c4df137c 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -17,6 +17,7 @@ #include "label_helper.h" #include "color_helper.h" #include "filter_match_proxy.h" +#include "in_mem_reorder_data_store.h" #if defined(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) #include "gperftools/malloc_extension.h" @@ -95,6 +96,9 @@ Index::Index(const IndexConfig &index_config, std::shared_ptrsearch_list_size; @@ -143,8 +147,7 @@ Index::Index(Metric m, const size_t dim, const size_t max_point .is_filtered(filtered_index) .with_data_type(diskann_type_to_name()) .build(), - IndexFactory::construct_datastore(DataStoreStrategy::MEMORY, - (max_points == 0 ? (size_t)1 : max_points), + IndexFactory::construct_mem_datastore((max_points == 0 ? (size_t)1 : max_points), dim, m), IndexFactory::construct_graphstore(GraphStoreStrategy::MEMORY, (max_points == 0 ? (size_t)1 : max_points), @@ -2367,7 +2370,8 @@ void Index::build_filtered_index(const char *filename, const st template std::pair Index::_search(const DataType &query, const size_t K, const uint32_t L, - std::any &indices, float *distances) + std::any &indices, float *distances, + std::function rerank_fn) { try { @@ -2375,12 +2379,12 @@ std::pair Index::_search(const DataType &qu if (typeid(uint32_t *) == indices.type()) { auto u32_ptr = std::any_cast(indices); - return this->search(typed_query, K, L, u32_ptr, distances); + return this->search(typed_query, K, L, u32_ptr, distances, 0, std::move(rerank_fn)); } else if (typeid(uint64_t *) == indices.type()) { auto u64_ptr = std::any_cast(indices); - return this->search(typed_query, K, L, u64_ptr, distances); + return this->search(typed_query, K, L, u64_ptr, distances, 0, std::move(rerank_fn)); } else { @@ -2399,7 +2403,7 @@ std::pair Index::_search(const DataType &qu template std::pair Index::_diverse_search(const DataType& query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, - std::any& indices, float* distances) + std::any& indices, float* distances, std::function rerank_fn) { try { @@ -2407,12 +2411,12 @@ std::pair Index::_diverse_search(const Data if (typeid(uint32_t*) == indices.type()) { auto u32_ptr = std::any_cast(indices); - return this->search(typed_query, K, L, u32_ptr, distances, maxLperSeller); + return this->search(typed_query, K, L, u32_ptr, distances, maxLperSeller, rerank_fn); } else if (typeid(uint64_t*) == indices.type()) { auto u64_ptr = std::any_cast(indices); - return this->search(typed_query, K, L, u64_ptr, distances, maxLperSeller); + return this->search(typed_query, K, L, u64_ptr, distances, maxLperSeller, rerank_fn); } else { @@ -2432,7 +2436,8 @@ std::pair Index::_diverse_search(const Data template template std::pair Index::search(const T *query, const size_t K, const uint32_t L, - IdType *indices, float *distances, const uint32_t maxLperSeller) + IdType *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn) { if (K > (uint64_t)L) { @@ -2459,41 +2464,7 @@ std::pair Index::search(const T *query, con auto retval = iterate_to_fixed_point(scratch, L, init_ids, false, unused_filter_label, true, maxLperSeller); - NeighborPriorityQueueBase* best_L_nodes; - if (!_diverse_index || maxLperSeller == 0) { - best_L_nodes = &(scratch->best_l_nodes()); - } - else { - best_L_nodes = &(scratch->best_diverse_nodes()); - } - - size_t pos = 0; - for (size_t i = 0; i < best_L_nodes->size(); ++i) - { - if ((*best_L_nodes)[i].id < _max_points) - { - // safe because Index uses uint32_t ids internally - // and IDType will be uint32_t or uint64_t - indices[pos] = (IdType)(*best_L_nodes)[i].id; - if (distances != nullptr) - { -#ifdef EXEC_ENV_OLS - // DLVS expects negative distances - distances[pos] = best_L_nodes[i].distance; -#else - distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT ? -1 * ((*best_L_nodes)[i].distance) - : (*best_L_nodes)[i].distance; -#endif - } - pos++; - } - if (pos == K) - break; - } - if (pos < K) - { - diskann::cerr << "Found pos: " << pos << "fewer than K elements " << K << " for query" << std::endl; - } + post_process_search_results(scratch, K, indices, distances, std::move(rerank_fn)); return retval; } @@ -2502,7 +2473,8 @@ template std::pair Index::_search_with_filters(const DataType &query, const std::vector &raw_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, std::any &indices, - float *distances) + float *distances, + std::function rerank_fn) { std::vector converted_labels; converted_labels.reserve(raw_labels.size()); @@ -2518,12 +2490,12 @@ std::pair Index::_search_with_filters(const if (typeid(uint64_t *) == indices.type()) { auto ptr = std::any_cast(indices); - return this->search_with_filters(std::any_cast(query), converted_labels, K, L, maxLperSeller, ptr, distances); + return this->search_with_filters(std::any_cast(query), converted_labels, K, L, maxLperSeller, ptr, distances, rerank_fn); } else if (typeid(uint32_t *) == indices.type()) { auto ptr = std::any_cast(indices); - return this->search_with_filters(std::any_cast(query), converted_labels, K, L, maxLperSeller, ptr, distances); + return this->search_with_filters(std::any_cast(query), converted_labels, K, L, maxLperSeller, ptr, distances, rerank_fn); } else { @@ -2535,7 +2507,8 @@ template template std::pair Index::search_with_filters(const T *query, const std::vector &filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, - IdType *indices, float *distances) + IdType *indices, float *distances, + std::function rerank_fn) { if (K > (uint64_t)L) { @@ -2591,6 +2564,16 @@ std::pair Index::search_with_filters(const _data_store->preprocess_query(query, scratch); auto retval = iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true, maxLperSeller); + post_process_search_results(scratch, K, indices, distances, std::move(rerank_fn)); + + return retval; +} + +template +template +void Index::post_process_search_results(InMemQueryScratch* scratch, const size_t K, IDType* indices, float* distances, + std::function rerank_fn) +{ NeighborPriorityQueueBase* best_L_nodes; if (!_diverse_index) { best_L_nodes = &(scratch->best_l_nodes()); @@ -2599,34 +2582,85 @@ std::pair Index::search_with_filters(const best_L_nodes = &(scratch->best_diverse_nodes()); } - size_t pos = 0; - for (size_t i = 0; i < best_L_nodes->size(); ++i) + // If reorder_index is enabled and rerank function is provided, rerank the results + if (_reorder_index && rerank_fn != nullptr) { - if ((*best_L_nodes)[i].id < _max_points) - { - indices[pos] = (IdType)(*best_L_nodes)[i].id; + // Use pre-allocated buffer from scratch to avoid repeated allocations + auto& reranked_results = scratch->reranked_results(); + reranked_results.clear(); + + auto reorder_data_store = dynamic_cast*>(_data_store.get()); + // Rerank each candidate using the provided rerank function + for (size_t i = 0; i < best_L_nodes->size(); ++i) + { + if ((*best_L_nodes)[i].id < _max_points) + { + // Get the original full-dimensional vector for this point + const uint8_t* data_ptr = reinterpret_cast(reorder_data_store->get_reorder_vector((*best_L_nodes)[i].id)); + + // Calculate the reranked distance using the full-dimensional data + float reranked_distance = rerank_fn(data_ptr, reorder_data_store->get_reorder_aligned_dim() * sizeof(T)); + + reranked_results.emplace_back((*best_L_nodes)[i].id, reranked_distance); + } + } + + // Sort by reranked distance + std::sort(reranked_results.begin(), reranked_results.end()); + + // Fill the output arrays with top K results + size_t num_results = std::min(K, reranked_results.size()); + for (size_t i = 0; i < num_results; ++i) + { + indices[i] = static_cast(reranked_results[i].id); if (distances != nullptr) { #ifdef EXEC_ENV_OLS // DLVS expects negative distances - distances[pos] = (*best_L_nodes)[i].distance; + distances[i] = reranked_results[i].distance; #else - distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT ? -1 * (*best_L_nodes)[i].distance - : (*best_L_nodes)[i].distance; + distances[i] = _dist_metric == diskann::Metric::INNER_PRODUCT ? -1 * reranked_results[i].distance + : reranked_results[i].distance; #endif } - pos++; } - if (pos == K) - break; + + if (num_results < K) + { + diskann::cerr << "Found fewer than K elements for query after reranking" << std::endl; + } } - if (pos < K) + else { - diskann::cerr << "Found fewer than K elements for query" << std::endl; - } + // Original process: copy results from best_L_nodes + size_t pos = 0; + for (size_t i = 0; i < best_L_nodes->size(); ++i) + { + if ((*best_L_nodes)[i].id < _max_points) + { + indices[pos] = static_cast((*best_L_nodes)[i].id); - return retval; + if (distances != nullptr) + { +#ifdef EXEC_ENV_OLS + // DLVS expects negative distances + distances[pos] = (*best_L_nodes)[i].distance; +#else + distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT ? -1 * (*best_L_nodes)[i].distance + : (*best_L_nodes)[i].distance; +#endif + } + pos++; + } + if (pos == K) + break; + } + if (pos < K) + { + diskann::cerr << "Found fewer than K elements for query" << std::endl; + } + } } template @@ -3851,130 +3885,154 @@ template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const std::vector &filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const float *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const uint8_t *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const uint8_t *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const int8_t *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const int8_t *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, - float *distances); + float *distances, std::function rerank_fn); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const float *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const uint8_t *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const uint8_t *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const int8_t *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const int8_t *query, const std::vector& filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller, + std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const std::vector &filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const float *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const uint8_t *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const uint8_t *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const int8_t *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const int8_t *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, - float *distances); + float *distances, std::function rerank_fn); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const float *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const uint8_t *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const uint8_t *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const int8_t *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, - float *distances); + float *distances, std::function rerank_fn); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const int8_t *query, const std::vector & filter_labels, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, - float *distances); + float *distances, std::function rerank_fn); } // namespace diskann diff --git a/src/index_factory.cpp b/src/index_factory.cpp index 8056bf4a0..55c856680 100644 --- a/src/index_factory.cpp +++ b/src/index_factory.cpp @@ -4,6 +4,7 @@ #include "in_mem_static_graph_store.h" #include "in_mem_graph_reformat_store.h" #include "in_mem_static_graph_reformat_store.h" +#include "in_mem_reorder_data_store.h" namespace diskann { @@ -67,17 +68,30 @@ template Distance *IndexFactory::construct_inmem_distance_fn(Met } template -std::shared_ptr> IndexFactory::construct_datastore(DataStoreStrategy strategy, - size_t total_internal_points, size_t dimension, - Metric metric) +std::shared_ptr> IndexFactory::construct_mem_datastore(size_t total_internal_points, size_t dimension, + Metric metric) { std::unique_ptr> distance; - switch (strategy) + distance.reset(construct_inmem_distance_fn(metric)); + return std::make_shared>((location_t)total_internal_points, dimension, + std::move(distance)); +} + +template +std::shared_ptr> IndexFactory::construct_datastore(const IndexConfig& index_config) +{ + std::unique_ptr> distance; + switch (index_config.data_strategy) { case DataStoreStrategy::MEMORY: - distance.reset(construct_inmem_distance_fn(metric)); - return std::make_shared>((location_t)total_internal_points, dimension, + distance.reset(construct_inmem_distance_fn(index_config.metric)); + return std::make_shared>((location_t)index_config.max_points, index_config.dimension, std::move(distance)); + case DataStoreStrategy::REORDER_MEMORY: + distance.reset(construct_inmem_distance_fn(index_config.metric)); + return std::make_shared>((location_t)index_config.max_points, + index_config.search_dim, index_config.dimension, + std::move(distance)); default: break; } @@ -131,7 +145,7 @@ std::unique_ptr IndexFactory::create_instance() size_t num_points = _config->max_points; size_t dim = _config->dimension; // auto graph_store = construct_graphstore(_config->graph_strategy, num_points); - auto data_store = construct_datastore(_config->data_strategy, num_points, dim, _config->metric); + auto data_store = construct_datastore(*_config); std::shared_ptr> pq_data_store = nullptr; if (_config->data_strategy == DataStoreStrategy::MEMORY && _config->pq_dist_build) @@ -217,11 +231,11 @@ std::unique_ptr IndexFactory::create_instance(const std::string & throw ANNException("Error: unsupported label_type please choose from [uint/ushort]", -1); } -// template DISKANN_DLLEXPORT std::shared_ptr> IndexFactory::construct_datastore( -// DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric m); -// template DISKANN_DLLEXPORT std::shared_ptr> IndexFactory::construct_datastore( -// DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric m); -// template DISKANN_DLLEXPORT std::shared_ptr> IndexFactory::construct_datastore( -// DataStoreStrategy stratagy, size_t num_points, size_t dimension, Metric m); + template DISKANN_DLLEXPORT std::shared_ptr> IndexFactory::construct_mem_datastore( + size_t num_points, size_t dimension, Metric m); + template DISKANN_DLLEXPORT std::shared_ptr> IndexFactory::construct_mem_datastore( + size_t num_points, size_t dimension, Metric m); + template DISKANN_DLLEXPORT std::shared_ptr> IndexFactory::construct_mem_datastore( + size_t num_points, size_t dimension, Metric m); } // namespace diskann diff --git a/src/scratch.cpp b/src/scratch.cpp index f52f843c6..d1fcb7462 100644 --- a/src/scratch.cpp +++ b/src/scratch.cpp @@ -38,6 +38,7 @@ InMemQueryScratch::InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, _inserted_into_pool_bs = new boost::dynamic_bitset<>(); _id_scratch.reserve((size_t)std::ceil(1.5 * defaults::GRAPH_SLACK_FACTOR * _R)); _dist_scratch.reserve((size_t)std::ceil(1.5 * defaults::GRAPH_SLACK_FACTOR * _R)); + _reranked_results.reserve(std::max(search_l, indexing_l)); resize_for_new_L(std::max(search_l, indexing_l)); @@ -64,6 +65,7 @@ template void InMemQueryScratch::clear() _expanded_nghrs_vec.clear(); _occlude_list_output.clear(); _query_label_bitmask.clear(); + _reranked_results.clear(); _best_diverse_nodes.clear(); }