@@ -19,18 +19,21 @@ std::vector<float> gen_vec(size_t dim, std::mt19937& rng) {
1919 std::uniform_real_distribution<float > dist (-1 .0f , 1 .0f );
2020 std::vector<float > vec (dim);
2121 float norm = 0 .0f ;
22- for (auto & v : vec) { v = dist (rng); norm += v * v; }
22+ for (auto & v : vec) {
23+ v = dist (rng);
24+ norm += v * v;
25+ }
2326 norm = std::sqrt (norm);
24- for (auto & v : vec) { v /= norm; }
27+ for (auto & v : vec) {
28+ v /= norm;
29+ }
2530 return vec;
2631}
2732
2833// Compute ground truth top-k for a single query (brute force)
29- std::unordered_set<size_t > compute_gt (
30- const std::vector<float >& query,
31- const std::vector<std::vector<float >>& corpus,
32- size_t k, const CosineMetric<float >& metric
33- ) {
34+ std::unordered_set<size_t > compute_gt (const std::vector<float >& query,
35+ const std::vector<std::vector<float >>& corpus, size_t k,
36+ const CosineMetric<float >& metric) {
3437 std::vector<std::pair<size_t , float >> dists;
3538 dists.reserve (corpus.size ());
3639 for (size_t i = 0 ; i < corpus.size (); ++i) {
@@ -40,7 +43,8 @@ std::unordered_set<size_t> compute_gt(
4043 std::partial_sort (dists.begin (), dists.begin () + k, dists.end (),
4144 [](auto & a, auto & b) { return a.second < b.second ; });
4245 std::unordered_set<size_t > result;
43- for (size_t i = 0 ; i < k; ++i) result.insert (dists[i].first );
46+ for (size_t i = 0 ; i < k; ++i)
47+ result.insert (dists[i].first );
4448 return result;
4549}
4650
@@ -51,24 +55,29 @@ int main() {
5155
5256 printf (" =================================================================\n " );
5357 printf (" 768d Recall Benchmark: M=24 vs M=32\n " );
54- printf (" Corpus: %zu, Dim: %zu, k: %zu, Queries: %zu, Threads: %zu\n " ,
55- corpus_size, dim, k, num_queries, num_threads);
58+ printf (" Corpus: %zu, Dim: %zu, k: %zu, Queries: %zu, Threads: %zu\n " , corpus_size, dim, k,
59+ num_queries, num_threads);
5660 printf (" =================================================================\n\n " );
5761 fflush (stdout);
5862
5963 // Generate corpus
60- printf (" Generating corpus...\n " ); fflush (stdout);
64+ printf (" Generating corpus...\n " );
65+ fflush (stdout);
6166 std::vector<std::vector<float >> vecs;
6267 vecs.reserve (corpus_size);
63- for (size_t i = 0 ; i < corpus_size; ++i) vecs.push_back (gen_vec (dim, rng));
68+ for (size_t i = 0 ; i < corpus_size; ++i)
69+ vecs.push_back (gen_vec (dim, rng));
6470
6571 // Generate queries
66- printf (" Generating queries...\n " ); fflush (stdout);
72+ printf (" Generating queries...\n " );
73+ fflush (stdout);
6774 std::vector<std::vector<float >> qvecs;
68- for (size_t i = 0 ; i < num_queries; ++i) qvecs.push_back (gen_vec (dim, rng));
75+ for (size_t i = 0 ; i < num_queries; ++i)
76+ qvecs.push_back (gen_vec (dim, rng));
6977
7078 // Parallel ground truth computation
71- printf (" Computing ground truth (parallel, %zu threads)...\n " , num_threads); fflush (stdout);
79+ printf (" Computing ground truth (parallel, %zu threads)...\n " , num_threads);
80+ fflush (stdout);
7281 auto gt_start = std::chrono::high_resolution_clock::now ();
7382
7483 std::vector<std::unordered_set<size_t >> gt_sets (num_queries);
@@ -85,42 +94,49 @@ int main() {
8594 for (size_t t = 0 ; t < num_threads; ++t) {
8695 size_t start = t * chunk;
8796 size_t end = std::min (start + chunk, num_queries);
88- if (start < end) threads.emplace_back (worker, start, end);
97+ if (start < end)
98+ threads.emplace_back (worker, start, end);
8999 }
90- for (auto & t : threads) t.join ();
100+ for (auto & t : threads)
101+ t.join ();
91102
92103 auto gt_end = std::chrono::high_resolution_clock::now ();
93- printf (" Ground truth: %.1fs\n\n " ,
94- std::chrono::duration<double >(gt_end - gt_start).count ());
104+ printf (" Ground truth: %.1fs\n\n " , std::chrono::duration<double >(gt_end - gt_start).count ());
95105 fflush (stdout);
96106
97107 printf (" %-6s | %-10s | %-12s | %-12s\n " , " M" , " ef_search" , " Recall@10" , " Latency(us)" );
98108 printf (" -------|------------|--------------|-------------\n " );
99109 fflush (stdout);
100110
101- for (auto [M, M_max, M_max_0] : {std::tuple{24 ,48 ,96 }, std::tuple{32 ,64 ,128 }}) {
111+ for (auto [M, M_max, M_max_0] : {std::tuple{24 , 48 , 96 }, std::tuple{32 , 64 , 128 }}) {
102112 // Build index
103113 HNSWIndex<float , CosineMetric<float >>::Config cfg;
104- cfg.M = M; cfg.M_max = M_max; cfg.M_max_0 = M_max_0; cfg.ef_construction = 200 ;
114+ cfg.M = M;
115+ cfg.M_max = M_max;
116+ cfg.M_max_0 = M_max_0;
117+ cfg.ef_construction = 200 ;
105118 HNSWIndex<float , CosineMetric<float >> idx (cfg);
106119
107120 auto build_start = std::chrono::high_resolution_clock::now ();
108121 for (size_t i = 0 ; i < corpus_size; ++i)
109122 idx.insert (i, std::span<const float >{vecs[i]});
110123 auto build_end = std::chrono::high_resolution_clock::now ();
111124 double build_s = std::chrono::duration<double >(build_end - build_start).count ();
112- printf (" M=%d built in %.1fs\n " , M, build_s); fflush (stdout);
125+ printf (" M=%d built in %.1fs\n " , M, build_s);
126+ fflush (stdout);
113127
114128 for (size_t ef : {50UL , 100UL , 200UL }) {
115129 size_t hits = 0 ;
116130 auto start = std::chrono::high_resolution_clock::now ();
117131 for (size_t q = 0 ; q < num_queries; ++q) {
118132 auto results = idx.search (std::span<const float >{qvecs[q]}, k, ef);
119133 for (const auto & [id, _] : results)
120- if (gt_sets[q].count (id)) ++hits;
134+ if (gt_sets[q].count (id))
135+ ++hits;
121136 }
122137 auto end = std::chrono::high_resolution_clock::now ();
123- double lat_us = std::chrono::duration<double , std::micro>(end - start).count () / num_queries;
138+ double lat_us =
139+ std::chrono::duration<double , std::micro>(end - start).count () / num_queries;
124140 double recall = 100.0 * hits / (num_queries * k);
125141 printf (" %-6d | %-10zu | %10.1f%% | %10.1f\n " , M, ef, recall, lat_us);
126142 fflush (stdout);
@@ -129,6 +145,7 @@ int main() {
129145 fflush (stdout);
130146 }
131147
132- printf (" \n Conclusion: If M=32 recall > M=24 by >3%%, consider extending Config::for_corpus()\n " );
148+ printf (
149+ " \n Conclusion: If M=32 recall > M=24 by >3%%, consider extending Config::for_corpus()\n " );
133150 return 0 ;
134151}
0 commit comments