Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion AnyBuildLogs/latest.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
20231019-111207-d314f8bf
20260110-171520-4f890e5e
18 changes: 12 additions & 6 deletions include/abstract_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,22 @@ class AbstractIndex
// IDtype is either uint32_t or uint64_t
template <typename data_type, typename IDType>
std::pair<uint32_t, uint32_t> search(const data_type *query, const size_t K, const uint32_t L, IDType *indices,
float *distances = nullptr);
float *distances = nullptr,
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr);

template <typename data_type, typename IDType>
std::pair<uint32_t, uint32_t> diverse_search(const data_type* query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, IDType* indices,
float* distances = nullptr);
float* distances = nullptr,
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr);

// Filter support search
// IndexType is either uint32_t or uint64_t
template <typename IndexType>
std::pair<uint32_t, uint32_t> search_with_filters(const DataType &query, const std::vector<std::string> &raw_labels,
const size_t K, const uint32_t L, const uint32_t maxLperSeller,
IndexType *indices,
float *distances);
float *distances,
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr);

// insert points with labels, labels should be present for filtered index
template <typename data_type, typename tag_type>
Expand Down Expand Up @@ -122,12 +125,15 @@ class AbstractIndex
private:
virtual void _build(const DataType &data, const size_t num_points_to_load, TagVector &tags) = 0;
virtual std::pair<uint32_t, uint32_t> _search(const DataType &query, const size_t K, const uint32_t L,
std::any &indices, float *distances = nullptr) = 0;
std::any &indices, float *distances = nullptr,
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr) = 0;
virtual std::pair<uint32_t, uint32_t> _diverse_search(const DataType& query, const size_t K, const uint32_t L, const uint32_t maxLperSeller,
std::any& indices, float* distances = nullptr) = 0;
std::any& indices, float* distances = nullptr,
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr) = 0;
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query, const std::vector<std::string> &filter_labels,
const size_t K, const uint32_t L, const uint32_t maxLperSeller, std::any &indices,
float *distances) = 0;
float *distances,
std::function<float(const std::uint8_t*, size_t)> rerank_fn) = 0;
virtual int _insert_point(const DataType &data_point, const TagType tag, const std::vector<std::string> &labels) = 0;
virtual int _insert_point(const DataType &data_point, const TagType tag) = 0;
virtual int _lazy_delete(const TagType &tag) = 0;
Expand Down
4 changes: 4 additions & 0 deletions include/defaults.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,9 @@ const uint32_t SEARCH_LIST_SIZE = 100;
const bool DIVERSE_INDEX = false;
const std::string EMPTY_STRING = "";
const bool NUM_DIVERSE_BUILD = 1;

const bool REORDER_INDEX = false;
const uint32_t REORDER_DIM = 0;

} // namespace defaults
} // namespace diskann
3 changes: 2 additions & 1 deletion include/in_mem_data_store.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once

#include <shared_mutex>
#include <memory>
Expand Down Expand Up @@ -72,7 +73,7 @@ template <typename data_t> class InMemDataStore : public AbstractDataStore<data_
virtual location_t load_impl(AlignedFileReader &reader);
#endif

private:
protected:
data_t *_data = nullptr;

size_t _aligned_dim;
Expand Down
62 changes: 62 additions & 0 deletions include/in_mem_reorder_data_store.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#pragma once
#include "in_mem_data_store.h"

namespace diskann
{

template <typename data_t>
class InMemReorderDataStore : public InMemDataStore<data_t>
{
public:
InMemReorderDataStore(location_t capacity, size_t search_dim, size_t reorder_dim,
std::unique_ptr<Distance<data_t>> search_distance_fn);

virtual ~InMemReorderDataStore() = default;

virtual size_t get_dims() const override;

virtual size_t get_aligned_dim() const override;

virtual void get_vector(const location_t i, data_t* target) const override;

virtual void set_vector(const location_t i, const data_t* const vector) override
{
throw std::runtime_error("set_vector not supported in InMemReorderDataStore");
}

virtual void prefetch_vector(const location_t loc) const override;

virtual float get_distance(const data_t* preprocessed_query, const location_t loc) const override;

virtual void get_distance(const data_t* preprocessed_query, const location_t* locations,
const uint32_t location_count, float* distances,
AbstractScratch<data_t>* scratch) const override;

virtual float get_distance(const location_t loc1, const location_t loc2) const override;

virtual void get_distance(const data_t* preprocessed_query, const std::vector<location_t>& ids,
std::vector<float>& distances, AbstractScratch<data_t>* scratch_space) const override;

size_t get_reorder_aligned_dim() const;

void get_reorder_vector(const location_t i, data_t *target) const;

const data_t* get_reorder_vector(const location_t i) const;

protected:
virtual location_t expand(const location_t new_size) override
{
throw std::runtime_error("expand not supported in InMemReorderDataStore");
}

virtual location_t shrink(const location_t new_size) override
{
throw std::runtime_error("shrink not supported in InMemReorderDataStore");
}

private:
size_t _search_dim = 0;
size_t _search_aligned_dim = 0;

};
}
25 changes: 19 additions & 6 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,25 +141,29 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
// can customize L on a per-query basis without tampering with "Parameters"
template <typename IDType>
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search(const T *query, const size_t K, const uint32_t L,
IDType *indices, float *distances = nullptr, const uint32_t maxLperSeller = 0);
IDType *indices, float *distances = nullptr, const uint32_t maxLperSeller = 0,
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr);

template <typename IDType>
std::pair<uint32_t, uint32_t> diverse_search(const T* query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, IDType* indices,
float* distances = nullptr);
float* distances = nullptr,
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr);

// Initialize space for res_vectors before calling.
DISKANN_DLLEXPORT size_t search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags,
float *distances, std::vector<T *> &res_vectors, bool use_filters,
const std::vector<std::string>& filter_labels);

virtual std::pair<uint32_t, uint32_t> _diverse_search(const DataType& query, const size_t K, const uint32_t L, const uint32_t maxLperSeller,
std::any& indices, float* distances = nullptr) override;
std::any& indices, float* distances = nullptr,
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr) override;

// Filter support search
template <typename IndexType>
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search_with_filters(const T *query, const std::vector<LabelT> &filter_labels,
const size_t K, const uint32_t L, const uint32_t maxLperSeller,
IndexType *indices, float *distances);
IndexType *indices, float *distances,
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr);

// Will fail if tag already in the index or if tag=0.
DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag);
Expand Down Expand Up @@ -220,11 +224,13 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
virtual void _build(const DataType &data, const size_t num_points_to_load, TagVector &tags) override;

virtual std::pair<uint32_t, uint32_t> _search(const DataType &query, const size_t K, const uint32_t L,
std::any &indices, float *distances = nullptr) override;
std::any &indices, float *distances = nullptr,
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr) override;
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query,
const std::vector<std::string> &filter_labels_raw, const size_t K,
const uint32_t L, const uint32_t maxLperSeller, std::any &indices,
float *distances) override;
float *distances,
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr) override;

virtual int _insert_point(const DataType &data_point, const TagType tag) override;
virtual int _insert_point(const DataType &data_point, const TagType tag, const std::vector<std::string> &labels) override;
Expand Down Expand Up @@ -328,6 +334,10 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
// Acquire exclusive _update_lock and _tag_lock before calling.
void resize(size_t new_max_points);

template <typename IDType>
void post_process_search_results(InMemQueryScratch<T> *scratch, const size_t K, IDType *indices, float *distances,
std::function<float(const std::uint8_t *, size_t)> rerank_fn);

// Acquire unique lock on _update_lock, _consolidate_lock, _tag_lock
// and _delete_lock before calling these functions.
// Renumber nodes, update tag and location maps and compact the
Expand Down Expand Up @@ -477,6 +487,9 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas

bool _use_integer_labels = false;
integer_label_vector _label_vector;

bool _reorder_index = false;
uint32_t _search_dim = 0; // only used when _reorder_index = true

TableStats _table_stats;

Expand Down
28 changes: 24 additions & 4 deletions include/index_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ namespace diskann
{
enum class DataStoreStrategy
{
MEMORY
MEMORY,
REORDER_MEMORY
};

enum class GraphStoreStrategy
Expand Down Expand Up @@ -39,6 +40,9 @@ struct IndexConfig
std::string tag_type;
std::string data_type;

bool reorder_index;
uint32_t search_dim;

// Params for building index
std::shared_ptr<IndexWriteParameters> index_write_params;
// Params for searching index
Expand All @@ -49,12 +53,13 @@ struct IndexConfig
size_t max_points, size_t num_pq_chunks, size_t num_frozen_points, bool dynamic_index, bool enable_tags,
bool pq_dist_build, bool concurrent_consolidate, bool use_opq, bool filtered_index,
std::string &data_type, const std::string &tag_type, const std::string &label_type,
bool reorder_index, uint32_t search_dim,
std::shared_ptr<IndexWriteParameters> index_write_params,
std::shared_ptr<IndexSearchParams> index_search_params)
: data_strategy(data_strategy), graph_strategy(graph_strategy), metric(metric), dimension(dimension),
max_points(max_points), dynamic_index(dynamic_index), enable_tags(enable_tags), pq_dist_build(pq_dist_build),
concurrent_consolidate(concurrent_consolidate), use_opq(use_opq), filtered_index(filtered_index),
num_pq_chunks(num_pq_chunks), num_frozen_pts(num_frozen_points), label_type(label_type), tag_type(tag_type),
num_pq_chunks(num_pq_chunks), num_frozen_pts(num_frozen_points), reorder_index(reorder_index), search_dim(search_dim), label_type(label_type), tag_type(tag_type),
data_type(data_type), index_write_params(index_write_params), index_search_params(index_search_params)
{
}
Expand Down Expand Up @@ -163,6 +168,18 @@ class IndexConfigBuilder
return *this;
}

IndexConfigBuilder &with_reorder_index(bool reorder_index)
{
this->_reorder_index = reorder_index;
return *this;
}

IndexConfigBuilder &with_search_dim(uint32_t search_dim)
{
this->_search_dim = search_dim;
return *this;
}

IndexConfigBuilder &with_index_write_params(IndexWriteParameters &index_write_params)
{
this->_index_write_params = std::make_shared<IndexWriteParameters>(index_write_params);
Expand Down Expand Up @@ -222,8 +239,8 @@ class IndexConfigBuilder

return IndexConfig(_data_strategy, _graph_strategy, _metric, _dimension, _max_points, _num_pq_chunks,
_num_frozen_pts, _dynamic_index, _enable_tags, _pq_dist_build, _concurrent_consolidate,
_use_opq, _filtered_index, _data_type, _tag_type, _label_type, _index_write_params,
_index_search_params);
_use_opq, _filtered_index, _data_type, _tag_type, _label_type, _reorder_index, _search_dim,
_index_write_params, _index_search_params);
}

IndexConfigBuilder(const IndexConfigBuilder &) = delete;
Expand Down Expand Up @@ -251,6 +268,9 @@ class IndexConfigBuilder
std::string _tag_type{"uint32"};
std::string _data_type;

bool _reorder_index = false;
uint32_t _search_dim = 0;

std::shared_ptr<IndexWriteParameters> _index_write_params;
std::shared_ptr<IndexSearchParams> _index_search_params;
};
Expand Down
8 changes: 5 additions & 3 deletions include/index_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ class IndexFactory
const GraphStoreStrategy stratagy, const size_t size, const size_t reserve_graph_degree);

template <typename T>
DISKANN_DLLEXPORT static std::shared_ptr<AbstractDataStore<T>> construct_datastore(DataStoreStrategy stratagy,
size_t num_points,
size_t dimension, Metric m);
DISKANN_DLLEXPORT static std::shared_ptr<AbstractDataStore<T>> construct_datastore(const IndexConfig& index_config);

template <typename T>
DISKANN_DLLEXPORT static std::shared_ptr<AbstractDataStore<T>> construct_mem_datastore(size_t total_internal_points, size_t dimension,
Metric metric);
// For now PQDataStore incorporates within itself all variants of quantization that we support. In the
// future it may be necessary to introduce an AbstractPQDataStore class to spearate various quantization
// flavours.
Expand Down
1 change: 0 additions & 1 deletion include/parameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ enum class LabelFormatType :uint8_t
};

class IndexWriteParameters

{
public:
const uint32_t search_list_size; // L
Expand Down
6 changes: 6 additions & 0 deletions include/scratch.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ template <typename T> class InMemQueryScratch : public AbstractScratch<T>
{
return _query_label_bitmask;
}
inline std::vector<Neighbor>& reranked_results()
{
return _reranked_results;
}

private:
uint32_t _L;
Expand Down Expand Up @@ -149,6 +153,8 @@ template <typename T> class InMemQueryScratch : public AbstractScratch<T>
std::vector<uint32_t> _occlude_list_output;
// bitmask buffer in searching time
std::vector<std::uint64_t> _query_label_bitmask;
// Buffer for reranking results to avoid repeated allocations
std::vector<Neighbor> _reranked_results;
};

//
Expand Down
Loading
Loading