Skip to content

Commit 0ac362a

Browse files
authored
Use small_vector (#30)
* fix * refactor * add small_vector * fix * update
1 parent ace767e commit 0ac362a

File tree

3 files changed

+1240
-85
lines changed

3 files changed

+1240
-85
lines changed

cpp/prtree.h

Lines changed: 78 additions & 84 deletions
Original file line number | Diff line number | Diff line change
@@ -13,14 +13,14 @@
1313
#include <mutex>
1414
#include <numeric>
1515
#include <queue>
16-
#include <random>
1716
#include <stack>
1817
#include <thread>
1918
#include <unordered_map>
2019
#include <vector>
2120
#include <string>
2221
#include <utility>
2322
#include <optional>
23+
#include <functional>
2424

2525
#include <pybind11/numpy.h>
2626
#include <pybind11/pybind11.h>
@@ -38,6 +38,7 @@
3838

3939
#include <snappy.h>
4040
#include "parallel.h"
41+
#include "small_vector.h"
4142

4243
#ifdef MY_DEBUG
4344
#include <gperftools/profiler.h>
@@ -50,13 +51,15 @@ namespace py = pybind11;
5051
template <class T>
5152
using vec = std::vector<T>;
5253

54+
template <class T, size_t StaticCapacity>
55+
using svec = itlib::small_vector<T, StaticCapacity>;
56+
5357
template <class T>
5458
using deque = std::deque<T>;
5559

5660
template <class T>
5761
using queue = std::queue<T, deque<T>>;
5862

59-
static std::mt19937 rand_src(42);
6063
static const float REBUILD_THRE = 1.25;
6164

6265
#if defined(__GNUC__) || defined(__clang__)
@@ -81,14 +84,6 @@ std::string decompress(std::string &data)
8184
return output;
8285
}
8386

84-
template <typename Iter>
85-
Iter select_randomly(Iter start, Iter end)
86-
{
87-
std::uniform_int_distribution<> dis(0, std::distance(start, end) - 1);
88-
std::advance(start, dis(rand_src));
89-
return start;
90-
}
91-
9287
template <int D = 2>
9388
class BB
9489
{
@@ -200,7 +195,8 @@ class BB
200195
{
201196
Real m = std::min(values[i], target.values[i]);
202197
Real M = std::min(values[i + D], target.values[i + D]);
203-
if (-m > M){
198+
if (-m > M)
199+
{
204200
return false;
205201
}
206202
}
@@ -264,18 +260,16 @@ class Leaf
264260
int axis = 0;
265261
Real min_val = 1e100;
266262
BB<D> mbb;
267-
vec<DataType<T, D>> data; // You can swap when filtering
263+
svec<DataType<T, D>, B> data; // You can swap when filtering
268264
// T is type of keys(ids) which will be returned when you post a query.
269265
Leaf()
270266
{
271267
mbb = BB<D>();
272-
data.reserve(B);
273268
}
274269
Leaf(const int _axis)
275270
{
276271
axis = _axis;
277272
mbb = BB<D>();
278-
data.reserve(B);
279273
}
280274

281275
Real area() const
@@ -284,7 +278,26 @@ class Leaf
284278
}
285279

286280
template <class Archive>
287-
void serialize(Archive &ar) { ar(axis, min_val, mbb, data); }
281+
// Serializes this Leaf into the archive `ar`.
//
// `data` is an svec (itlib::small_vector alias), so its elements are first
// copied one by one into a temporary vec (std::vector alias) and that vector
// is what gets archived, together with axis, min_val and mbb.
// NOTE(review): presumably done because the archive has no serializer for
// itlib::small_vector - confirm.
// NOTE(review): `_data` is grown by push_back without a prior reserve();
// reserving up front would avoid reallocations during the copy.
void save(Archive &ar) const
282+
{
283+
vec<DataType<T, D>> _data;
284+
for (const auto &datum : data)
285+
{
286+
_data.push_back(datum);
287+
}
288+
// Field order here must match the order read back in load().
ar(axis, min_val, mbb, _data);
289+
}
290+
291+
// Restores this Leaf from the archive `ar`.
//
// Reads axis, min_val, mbb and a temporary vec (std::vector alias) of
// entries, then appends every entry to the member `data` with push_back.
// NOTE(review): `data` is not cleared before appending, so loading into a
// Leaf that already holds entries would add to them rather than replace
// them - confirm that load() is only ever invoked on freshly constructed
// Leaf objects.
template <class Archive>
292+
void load(Archive &ar)
293+
{
294+
vec<DataType<T, D>> _data;
295+
// Same field order as written by save().
ar(axis, min_val, mbb, _data);
296+
for (const auto &datum : _data)
297+
{
298+
data.push_back(datum);
299+
}
300+
}
288301

289302
void set_axis(const int &_axis) { axis = _axis; }
290303

@@ -323,7 +336,7 @@ class Leaf
323336
}
324337
else
325338
{ // if there is no room, check the priority and swap if needed
326-
/*
339+
/*
327340
auto iter = std::upper_bound(data.begin(), data.end(), value, [&](const auto &a, const auto &b) noexcept
328341
{ return a.second[axis] < b.second[axis]; });
329342
if (iter != data.end())
@@ -580,6 +593,46 @@ class PRTreeNode
580593
}
581594
};
582595

596+
// Breadth-first walk over the tree rooted at `root`, invoking `func` on the
// leaf of every visited node that has one.
//
// A node is only enqueued (and therefore only visited) when calling the node
// with `target`, i.e. (*node)(target), yields true; judging by the helper's
// name this is presumably a bounding-box intersection test - confirm against
// PRTreeNode::operator().
//
// Parameters:
//   func   - callback receiving a reference to the node's
//            std::unique_ptr<Leaf<T, B, D>>; called once per visited node
//            whose `leaf` pointer is non-null.
//   root   - starting node; it is dereferenced unconditionally, so it must
//            not be null.
//   target - box handed to every node test.
//
// NOTE(review): `target` is taken by value, so the BB<D> is copied on every
// call; a const reference would avoid the copy and stays compatible with the
// explicit bfs<T, B, D>(...) call sites.
// NOTE(review): `func` is a std::function, which adds a type-erased indirect
// call per leaf; a templated callable parameter would let the compiler
// inline the lambdas passed by the callers.
template <class T, int B = 6, int D = 2>
597+
void bfs(const std::function<void(std::unique_ptr<Leaf<T, B, D>> &)> &func, PRTreeNode<T, B, D> *root, const BB<D> target)
598+
{
599+
// FIFO work list of nodes that passed the test against `target`.
queue<PRTreeNode<T, B, D> *> que;
600+
PRTreeNode<T, B, D> *p, *q;
601+
// Enqueue `r` only if the node accepts `target`; rejected nodes are never visited.
auto qpush_if_intersect = [&](PRTreeNode<T, B, D> *r)
602+
{
603+
if ((*r)(target))
604+
{
605+
que.emplace(r);
606+
}
607+
};
608+
609+
p = root;
610+
qpush_if_intersect(p);
611+
while (!que.empty())
612+
{
613+
p = que.front();
614+
que.pop();
615+
616+
if (p->leaf)
617+
{
618+
// Node carries a leaf: hand it to the callback and do not descend further.
func(p->leaf);
619+
}
620+
else
621+
{
622+
if (p->head)
623+
{
624+
// Children are reached through `head` and then chained by `next`;
// test and enqueue each one in turn.
q = p->head.get();
625+
qpush_if_intersect(q);
626+
while (q->next)
627+
{
628+
q = q->next.get();
629+
qpush_if_intersect(q);
630+
}
631+
}
632+
}
633+
}
634+
}
635+
583636
template <class T, int B = 6, int D = 2>
584637
class PRTree
585638
{
@@ -1076,13 +1129,11 @@ class PRTree
10761129
#ifdef MY_DEBUG
10771130
std::for_each(X.begin(), X.end(),
10781131
[&](const BB<D> &x)
1079-
{
1080-
out.push_back(find(x)); });
1132+
{ out.push_back(find(x)); });
10811133
#else
10821134
parallel_for_each(X.begin(), X.end(), out,
10831135
[&](const BB<D> &x, auto &o)
1084-
{
1085-
o.push_back(find(x)); });
1136+
{ o.push_back(find(x)); });
10861137
#endif
10871138
#ifdef MY_DEBUG
10881139
ProfilerStop();
@@ -1124,41 +1175,12 @@ class PRTree
11241175
vec<T> find(const BB<D> &target)
11251176
{
11261177
vec<T> out;
1127-
queue<PRTreeNode<T, B, D> *> que;
1128-
PRTreeNode<T, B, D> *p, *q;
1129-
auto qpush_if_intersect = [&](PRTreeNode<T, B, D> *r)
1178+
auto func = [&](std::unique_ptr<Leaf<T, B, D>> &leaf)
11301179
{
1131-
if ((*r)(target))
1132-
{
1133-
que.emplace(r);
1134-
}
1180+
(*leaf)(target, out);
11351181
};
11361182

1137-
p = root.get();
1138-
qpush_if_intersect(p);
1139-
while (!que.empty())
1140-
{
1141-
p = que.front();
1142-
que.pop();
1143-
1144-
if (p->leaf)
1145-
{
1146-
(*p->leaf)(target, out);
1147-
}
1148-
else
1149-
{
1150-
if (p->head)
1151-
{
1152-
q = p->head.get();
1153-
qpush_if_intersect(q);
1154-
while (q->next)
1155-
{
1156-
q = q->next.get();
1157-
qpush_if_intersect(q);
1158-
}
1159-
}
1160-
}
1161-
}
1183+
bfs<T, B, D>(func, root.get(), target);
11621184
return out;
11631185
}
11641186

@@ -1170,41 +1192,13 @@ class PRTree
11701192
throw std::runtime_error("Given index is not found.");
11711193
}
11721194
BB<D> target = it->second;
1173-
queue<PRTreeNode<T, B, D> *> que;
1174-
PRTreeNode<T, B, D> *p, *q;
1175-
auto qpush_if_intersect = [&](PRTreeNode<T, B, D> *r)
1176-
{
1177-
if ((*r)(target))
1178-
{
1179-
que.emplace(r);
1180-
}
1181-
};
11821195

1183-
p = root.get();
1184-
qpush_if_intersect(p);
1185-
while (!que.empty())
1196+
auto func = [&](std::unique_ptr<Leaf<T, B, D>> &leaf)
11861197
{
1187-
p = que.front();
1188-
que.pop();
1198+
leaf->del(idx, target);
1199+
};
11891200

1190-
if (p->leaf)
1191-
{
1192-
p->leaf->del(idx, target);
1193-
}
1194-
else
1195-
{
1196-
if (p->head)
1197-
{
1198-
q = p->head.get();
1199-
qpush_if_intersect(q);
1200-
while (q->next)
1201-
{
1202-
q = q->next.get();
1203-
qpush_if_intersect(q);
1204-
}
1205-
}
1206-
}
1207-
}
1201+
bfs<T, B, D>(func, root.get(), target);
12081202

12091203
idx2bb.erase(idx);
12101204
idx2data.erase(idx);

0 commit comments

Comments
 (0)