Skip to content

Commit 828c0ed

Browse files
committed
bug fixes and additional optimizations
1 parent c69d1ed commit 828c0ed

2 files changed

Lines changed: 113 additions & 1 deletion

File tree

include/sqlite-vec-cpp/index/hnsw.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,8 @@ template <concepts::VectorElement StorageT, typename MetricT> class HNSWIndex {
640640
}
641641

642642
// Upper layer stats
643-
for (size_t layer = 1; layer <= node.max_layer(); ++layer) {
643+
size_t node_max_layer = node.num_layers() > 0 ? node.num_layers() - 1 : 0;
644+
for (size_t layer = 1; layer <= node_max_layer; ++layer) {
644645
size_t degree = node.neighbors(layer).size();
645646
total_degree_upper += degree;
646647
stats.total_edges += degree;

tests/test_hnsw.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,114 @@ void test_fp16_accuracy() {
763763
std::cout << " ✓ fp16 accuracy passed" << std::endl;
764764
}
765765

766+
// Test 19: Graph quality metrics
767+
void test_graph_stats() {
768+
std::cout << "Test 19: Graph quality metrics..." << std::endl;
769+
770+
constexpr size_t num_vectors = 1000;
771+
constexpr size_t dim = 64;
772+
std::mt19937 rng(42);
773+
774+
HNSWIndex<float, L2Metric<float>> index;
775+
for (size_t i = 0; i < num_vectors; ++i) {
776+
auto vec = generate_vector(dim, rng);
777+
index.insert(i, std::span{vec});
778+
}
779+
780+
auto stats = index.compute_graph_stats();
781+
782+
std::cout << " Nodes: " << stats.num_nodes << std::endl;
783+
std::cout << " Layers: " << stats.num_layers << std::endl;
784+
std::cout << " Total edges: " << stats.total_edges << std::endl;
785+
std::cout << " Avg degree (layer 0): " << stats.avg_degree_layer0 << std::endl;
786+
std::cout << " Min/Max degree (layer 0): " << stats.min_degree_layer0 << "/"
787+
<< stats.max_degree_layer0 << std::endl;
788+
std::cout << " Orphan nodes: " << stats.orphan_count << std::endl;
789+
std::cout << " Connectivity score: " << (stats.connectivity_score * 100) << "%" << std::endl;
790+
791+
// Assertions
792+
assert(stats.num_nodes == num_vectors);
793+
assert(stats.num_layers >= 1);
794+
assert(stats.orphan_count == 0); // No orphans for healthy graph
795+
assert(stats.avg_degree_layer0 >= 4.0); // Should have reasonable connectivity
796+
assert(stats.is_healthy()); // Graph should be healthy
797+
798+
std::cout << " ✓ Graph stats passed (healthy=" << (stats.is_healthy() ? "yes" : "no") << ")"
799+
<< std::endl;
800+
}
801+
802+
// Test 20: Adaptive ef_search
803+
void test_adaptive_search() {
804+
std::cout << "Test 20: Adaptive ef_search..." << std::endl;
805+
806+
constexpr size_t num_vectors = 5000;
807+
constexpr size_t dim = 128;
808+
constexpr size_t k = 10;
809+
std::mt19937 rng(42);
810+
811+
HNSWIndex<float, L2Metric<float>> index;
812+
std::vector<std::vector<float>> vectors;
813+
vectors.reserve(num_vectors);
814+
815+
for (size_t i = 0; i < num_vectors; ++i) {
816+
vectors.push_back(generate_vector(dim, rng));
817+
index.insert(i, std::span{vectors[i]});
818+
}
819+
820+
// Check recommended ef_search values
821+
size_t ef_90 = index.recommended_ef_search(k, 0.90f);
822+
size_t ef_95 = index.recommended_ef_search(k, 0.95f);
823+
size_t ef_99 = index.recommended_ef_search(k, 0.99f);
824+
825+
std::cout << " Corpus size: " << num_vectors << std::endl;
826+
std::cout << " ef_search for 90% recall: " << ef_90 << std::endl;
827+
std::cout << " ef_search for 95% recall: " << ef_95 << std::endl;
828+
std::cout << " ef_search for 99% recall: " << ef_99 << std::endl;
829+
830+
// Higher target recall should require higher ef_search
831+
assert(ef_95 >= ef_90);
832+
assert(ef_99 >= ef_95);
833+
assert(ef_90 >= k); // ef_search should always be >= k
834+
835+
// Test adaptive search returns results
836+
auto query = generate_vector(dim, rng);
837+
auto results = index.search_adaptive(std::span{query}, k, 0.95f);
838+
assert(results.size() == k);
839+
840+
std::cout << " ✓ Adaptive search passed" << std::endl;
841+
}
842+
843+
// Test 21: Config::for_corpus factory
844+
void test_config_for_corpus() {
845+
std::cout << "Test 21: Config::for_corpus factory..." << std::endl;
846+
847+
using Config = HNSWIndex<float, L2Metric<float>>::Config;
848+
849+
// Small corpus, low dim
850+
auto cfg_small = Config::for_corpus(1000, 64);
851+
std::cout << " 1K vectors, dim=64: M=" << cfg_small.M << ", M_max_0=" << cfg_small.M_max_0
852+
<< ", ef_construction=" << cfg_small.ef_construction << std::endl;
853+
854+
// Medium corpus, medium dim
855+
auto cfg_med = Config::for_corpus(50000, 128);
856+
std::cout << " 50K vectors, dim=128: M=" << cfg_med.M << ", M_max_0=" << cfg_med.M_max_0
857+
<< ", ef_construction=" << cfg_med.ef_construction << std::endl;
858+
859+
// Large corpus, high dim
860+
auto cfg_large = Config::for_corpus(500000, 384);
861+
std::cout << " 500K vectors, dim=384: M=" << cfg_large.M << ", M_max_0=" << cfg_large.M_max_0
862+
<< ", ef_construction=" << cfg_large.ef_construction << std::endl;
863+
864+
// Larger corpus should have higher ef_construction
865+
assert(cfg_large.ef_construction >= cfg_med.ef_construction);
866+
assert(cfg_med.ef_construction >= cfg_small.ef_construction);
867+
868+
// Higher dim should have higher M
869+
assert(cfg_large.M >= cfg_med.M);
870+
871+
std::cout << " ✓ Config factory passed" << std::endl;
872+
}
873+
766874
int main() {
767875
std::cout << "Running HNSW tests...\n" << std::endl;
768876

@@ -785,6 +893,9 @@ int main() {
785893
test_parallel_build();
786894
test_fp16_storage();
787895
test_fp16_accuracy();
896+
test_graph_stats();
897+
test_adaptive_search();
898+
test_config_for_corpus();
788899

789900
std::cout << "\nAll HNSW tests passed!" << std::endl;
790901
return 0;

0 commit comments

Comments
 (0)