1313#include < mutex>
1414#include < numeric>
1515#include < queue>
16- #include < random>
1716#include < stack>
1817#include < thread>
1918#include < unordered_map>
2019#include < vector>
2120#include < string>
2221#include < utility>
2322#include < optional>
23+ #include < functional>
2424
2525#include < pybind11/numpy.h>
2626#include < pybind11/pybind11.h>
3838
3939#include < snappy.h>
4040#include " parallel.h"
41+ #include " small_vector.h"
4142
4243#ifdef MY_DEBUG
4344#include < gperftools/profiler.h>
@@ -50,13 +51,15 @@ namespace py = pybind11;
5051template <class T >
5152using vec = std::vector<T>;
5253
54+ template <class T , size_t StaticCapacity>
55+ using svec = itlib::small_vector<T, StaticCapacity>;
56+
5357template <class T >
5458using deque = std::deque<T>;
5559
5660template <class T >
5761using queue = std::queue<T, deque<T>>;
5862
59- static std::mt19937 rand_src (42 );
6063static const float REBUILD_THRE = 1.25 ;
6164
6265#if defined(__GNUC__) || defined(__clang__)
@@ -81,14 +84,6 @@ std::string decompress(std::string &data)
8184 return output;
8285}
8386
84- template <typename Iter>
85- Iter select_randomly (Iter start, Iter end)
86- {
87- std::uniform_int_distribution<> dis (0 , std::distance (start, end) - 1 );
88- std::advance (start, dis (rand_src));
89- return start;
90- }
91-
9287template <int D = 2 >
9388class BB
9489{
@@ -200,7 +195,8 @@ class BB
200195 {
201196 Real m = std::min (values[i], target.values [i]);
202197 Real M = std::min (values[i + D], target.values [i + D]);
203- if (-m > M){
198+ if (-m > M)
199+ {
204200 return false ;
205201 }
206202 }
@@ -264,18 +260,16 @@ class Leaf
264260 int axis = 0 ;
265261 Real min_val = 1e100 ;
266262 BB<D> mbb;
267- vec <DataType<T, D>> data; // You can swap when filtering
263+ svec <DataType<T, D>, B > data; // You can swap when filtering
268264 // T is type of keys(ids) which will be returned when you post a query.
269265 Leaf ()
270266 {
271267 mbb = BB<D>();
272- data.reserve (B);
273268 }
274269 Leaf (const int _axis)
275270 {
276271 axis = _axis;
277272 mbb = BB<D>();
278- data.reserve (B);
279273 }
280274
281275 Real area () const
@@ -284,7 +278,26 @@ class Leaf
284278 }
285279
286280 template <class Archive >
287- void serialize (Archive &ar) { ar (axis, min_val, mbb, data); }
281+ void save (Archive &ar) const
282+ {
283+ vec<DataType<T, D>> _data;
284+ for (const auto &datum : data)
285+ {
286+ _data.push_back (datum);
287+ }
288+ ar (axis, min_val, mbb, _data);
289+ }
290+
291+ template <class Archive >
292+ void load (Archive &ar)
293+ {
294+ vec<DataType<T, D>> _data;
295+ ar (axis, min_val, mbb, _data);
296+ for (const auto &datum : _data)
297+ {
298+ data.push_back (datum);
299+ }
300+ }
288301
289302 void set_axis (const int &_axis) { axis = _axis; }
290303
@@ -323,7 +336,7 @@ class Leaf
323336 }
324337 else
325338 { // if there is no room, check the priority and swap if needed
326- /*
339+ /*
327340 auto iter = std::upper_bound(data.begin(), data.end(), value, [&](const auto &a, const auto &b) noexcept
328341 { return a.second[axis] < b.second[axis]; });
329342 if (iter != data.end())
@@ -580,6 +593,46 @@ class PRTreeNode
580593 }
581594};
582595
596+ template <class T , int B = 6 , int D = 2 >
597+ void bfs (const std::function<void (std::unique_ptr<Leaf<T, B, D>> &)> &func, PRTreeNode<T, B, D> *root, const BB<D> target)
598+ {
599+ queue<PRTreeNode<T, B, D> *> que;
600+ PRTreeNode<T, B, D> *p, *q;
601+ auto qpush_if_intersect = [&](PRTreeNode<T, B, D> *r)
602+ {
603+ if ((*r)(target))
604+ {
605+ que.emplace (r);
606+ }
607+ };
608+
609+ p = root;
610+ qpush_if_intersect (p);
611+ while (!que.empty ())
612+ {
613+ p = que.front ();
614+ que.pop ();
615+
616+ if (p->leaf )
617+ {
618+ func (p->leaf );
619+ }
620+ else
621+ {
622+ if (p->head )
623+ {
624+ q = p->head .get ();
625+ qpush_if_intersect (q);
626+ while (q->next )
627+ {
628+ q = q->next .get ();
629+ qpush_if_intersect (q);
630+ }
631+ }
632+ }
633+ }
634+ }
635+
583636template <class T , int B = 6 , int D = 2 >
584637class PRTree
585638{
@@ -1076,13 +1129,11 @@ class PRTree
10761129#ifdef MY_DEBUG
10771130 std::for_each (X.begin (), X.end (),
10781131 [&](const BB<D> &x)
1079- {
1080- out.push_back (find (x)); });
1132+ { out.push_back (find (x)); });
10811133#else
10821134 parallel_for_each (X.begin (), X.end (), out,
10831135 [&](const BB<D> &x, auto &o)
1084- {
1085- o.push_back (find (x)); });
1136+ { o.push_back (find (x)); });
10861137#endif
10871138#ifdef MY_DEBUG
10881139 ProfilerStop ();
@@ -1124,41 +1175,12 @@ class PRTree
11241175 vec<T> find (const BB<D> &target)
11251176 {
11261177 vec<T> out;
1127- queue<PRTreeNode<T, B, D> *> que;
1128- PRTreeNode<T, B, D> *p, *q;
1129- auto qpush_if_intersect = [&](PRTreeNode<T, B, D> *r)
1178+ auto func = [&](std::unique_ptr<Leaf<T, B, D>> &leaf)
11301179 {
1131- if ((*r)(target))
1132- {
1133- que.emplace (r);
1134- }
1180+ (*leaf)(target, out);
11351181 };
11361182
1137- p = root.get ();
1138- qpush_if_intersect (p);
1139- while (!que.empty ())
1140- {
1141- p = que.front ();
1142- que.pop ();
1143-
1144- if (p->leaf )
1145- {
1146- (*p->leaf )(target, out);
1147- }
1148- else
1149- {
1150- if (p->head )
1151- {
1152- q = p->head .get ();
1153- qpush_if_intersect (q);
1154- while (q->next )
1155- {
1156- q = q->next .get ();
1157- qpush_if_intersect (q);
1158- }
1159- }
1160- }
1161- }
1183+ bfs<T, B, D>(func, root.get (), target);
11621184 return out;
11631185 }
11641186
@@ -1170,41 +1192,13 @@ class PRTree
11701192 throw std::runtime_error (" Given index is not found." );
11711193 }
11721194 BB<D> target = it->second ;
1173- queue<PRTreeNode<T, B, D> *> que;
1174- PRTreeNode<T, B, D> *p, *q;
1175- auto qpush_if_intersect = [&](PRTreeNode<T, B, D> *r)
1176- {
1177- if ((*r)(target))
1178- {
1179- que.emplace (r);
1180- }
1181- };
11821195
1183- p = root.get ();
1184- qpush_if_intersect (p);
1185- while (!que.empty ())
1196+ auto func = [&](std::unique_ptr<Leaf<T, B, D>> &leaf)
11861197 {
1187- p = que. front ( );
1188- que. pop () ;
1198+ leaf-> del (idx, target );
1199+ } ;
11891200
1190- if (p->leaf )
1191- {
1192- p->leaf ->del (idx, target);
1193- }
1194- else
1195- {
1196- if (p->head )
1197- {
1198- q = p->head .get ();
1199- qpush_if_intersect (q);
1200- while (q->next )
1201- {
1202- q = q->next .get ();
1203- qpush_if_intersect (q);
1204- }
1205- }
1206- }
1207- }
1201+ bfs<T, B, D>(func, root.get (), target);
12081202
12091203 idx2bb.erase (idx);
12101204 idx2data.erase (idx);
0 commit comments