Skip to content

Commit f23aec1

Browse files
committed
位于分支 master
您的分支与上游分支 'origin/master' 一致。 要提交的变更: 修改: main.tex 修改: nearest_point_pair.cpp 修改: str/suffix_array.cpp
1 parent 6cb635e commit f23aec1

File tree

3 files changed

+105
-82
lines changed

3 files changed

+105
-82
lines changed

main.tex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ \chapter{计算几何}
246246

247247
\lstinputlisting[language=C++, caption=geometry.cpp, style=MyCStyle]{./geometry.cpp}
248248

249-
\section{最近点对}
249+
\section*{最近点对}
250250

251251
在平面中, 我们经常需要找到两个最近的点. 相比最传统的暴力方法, 即找到所有点对, 并计算出任意两者之间的距离, 一些方法可以大幅度简化计算, 进而有效降低我们的程序的时间复杂度. 本题中, 我们采用了一种分治与合并的方法, 从概率上, 让算法的复杂度降低到 $O(nlog(n))$.
252252

nearest_point_pair.cpp

Lines changed: 104 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -10,71 +10,126 @@
1010
* @details
1111
*/
1212

13+
/**
14+
* @file nearest_point_pair.cpp
15+
* @brief
16+
* @author Haoming Bai <haomingbai@hotmail.com>
17+
* @date 2025-09-16
18+
*
19+
* Copyright © 2025 Haoming Bai
20+
* SPDX-License-Identifier: MIT
21+
*
22+
* @details
23+
*/
24+
1325
#include <algorithm>
14-
#include <cmath>
1526
#include <cstddef>
16-
#include <cstdint>
17-
#include <cstdlib>
27+
#include <functional>
28+
#include <limits>
1829
#include <span>
30+
#include <vector>
31+
32+
#include "./concepts.cpp"
1933

34+
template <Multiplyable T>
2035
struct Point2D {
21-
double x, y;
36+
double x, y; // 保持你原来的设计不变(模板参数存在但成员为 double)
2237
};
2338

24-
double DistanceWith(const Point2D &a, const Point2D &b) {
25-
return std::sqrt((a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y));
39+
// 保持签名但改为用 T 做整数运算(将坐标 cast 为 T)
40+
// 假设你已经声明输入为整数且平方不会溢出(按你要求)
41+
template <Multiplyable T>
42+
T DistanceSquareWith(const Point2D<T>& a, const Point2D<T>& b) {
43+
T dx = static_cast<T>(a.x) - static_cast<T>(b.x);
44+
T dy = static_cast<T>(a.y) - static_cast<T>(b.y);
45+
return dx * dx + dy * dy;
2646
}
2747

28-
double FindNearestDistance(std::span<Point2D> point_list) {
29-
if (point_list.size() == 2) {
30-
return DistanceWith(point_list.front(), point_list.back());
31-
} else if (point_list.size() < 2) {
32-
return MAXFLOAT;
33-
}
34-
35-
auto left_span = point_list.subspan(0, point_list.size() / 2);
36-
auto right_span = point_list.subspan(point_list.size() / 2);
37-
38-
double distance_left = FindNearestDistance(left_span);
39-
double distance_right = FindNearestDistance(right_span);
40-
41-
auto dist_to_cmp = std::min(distance_left, distance_right);
42-
43-
auto mid_x = (point_list[point_list.size() / 2 - 1].x +
44-
point_list[point_list.size() / 2].x) /
45-
2;
46-
size_t left_edge = SIZE_MAX, right_edge = SIZE_MAX;
47-
size_t left_mid_edge = left_span.size() - 1;
48-
size_t right_mid_edge = point_list.size() / 2;
49-
50-
{
51-
for (ptrdiff_t i = left_span.size() - 1; i >= 0; i--) {
52-
if (std::abs(left_span[i].x - mid_x) < dist_to_cmp) {
53-
left_edge = i;
54-
} else {
55-
break;
48+
// 保持外部签名不变:输入是按 x 排序好的点(你的 main 已经做了 sort)
49+
template <Multiplyable T>
50+
T FindNearestDistanceSquare(std::span<Point2D<T>> point_list) {
51+
const size_t n = point_list.size();
52+
if (n < 2) return std::numeric_limits<T>::max();
53+
if (n == 2) return DistanceSquareWith(point_list.front(), point_list.back());
54+
55+
// 为递归使用,直接操作底层数据指针,避免频繁拷贝 span 对象
56+
Point2D<T>* base = point_list.data();
57+
58+
// 递归 lambda,半开区间 [l, r)
59+
std::function<T(size_t, size_t)> rec;
60+
rec = [&](size_t l, size_t r) -> T {
61+
size_t len = r - l;
62+
if (len < 2) {
63+
return std::numeric_limits<T>::max();
64+
}
65+
if (len == 2) {
66+
// 两点时按原样比较(不改变顺序)
67+
return DistanceSquareWith(base[l], base[l + 1]);
68+
}
69+
70+
size_t m = l + (len >> 1);
71+
// 保存分割线 x 坐标。**注意**:必须在递归之前或基于当前 x 排序的假设确定
72+
// mid_x。 这里我们使用分割点为右半部分首元素的
73+
// x(常见做法,避免浮点平均)。
74+
T mid_x = static_cast<T>(base[m].x);
75+
76+
// 递归求左右最短
77+
T dl = rec(l, m);
78+
T dr = rec(m, r);
79+
T d = dl < dr ? dl : dr;
80+
81+
// 找到横向候选区间:从中间向两边线性扩展直到 dx^2 >= d
82+
// (比起二分查找,这里更简单、分支更少,且通常很快——因为带宽一般较小)
83+
size_t left_edge = m; // inclusive
84+
if (m > l) {
85+
for (ptrdiff_t i = static_cast<ptrdiff_t>(m) - 1;
86+
i >= static_cast<ptrdiff_t>(l); --i) {
87+
T dx = static_cast<T>(base[i].x) - mid_x;
88+
if (dx * dx < d)
89+
left_edge = static_cast<size_t>(i);
90+
else
91+
break;
5692
}
5793
}
5894

59-
for (long i = 0; i < right_span.size(); i++) {
60-
if (std::abs(right_span[i].x - mid_x) < dist_to_cmp) {
61-
right_edge = i + right_mid_edge;
62-
} else {
95+
size_t right_edge = m; // exclusive
96+
for (size_t j = m; j < r; ++j) {
97+
T dx = static_cast<T>(base[j].x) - mid_x;
98+
if (dx * dx < d)
99+
right_edge = j + 1; // j included
100+
else
63101
break;
64-
}
65102
}
66-
}
67103

68-
if (left_edge == SIZE_MAX || right_edge == SIZE_MAX) {
69-
return dist_to_cmp;
70-
}
104+
// 如果没有跨中线的候选点,直接返回
105+
if (left_edge >= right_edge) return d;
106+
107+
// 把候选点放到临时数组并按 y 排序(这样在 strip 内可以早停)
108+
std::vector<Point2D<T>*> strip;
109+
strip.reserve(right_edge - left_edge);
110+
for (size_t idx = left_edge; idx < right_edge; ++idx)
111+
strip.push_back(&base[idx]);
71112

72-
for (auto i = left_edge; i <= left_mid_edge; i++) {
73-
for (auto j = right_mid_edge; j <= right_edge; j++) {
74-
dist_to_cmp =
75-
std::min(dist_to_cmp, DistanceWith(point_list[i], point_list[j]));
113+
std::sort(strip.begin(), strip.end(),
114+
[](const Point2D<T>* a, const Point2D<T>* b) {
115+
if (a->y != b->y) return a->y < b->y;
116+
return a->x < b->x;
117+
});
118+
119+
// 经典 strip 比较:对每个点只检查 y 差绝对值小于 sqrt(d) 的后续点
120+
// 这里比较使用 dy^2 >= d 的早停条件(避免 sqrt)
121+
for (size_t i = 0; i < strip.size(); ++i) {
122+
for (size_t j = i + 1; j < strip.size(); ++j) {
123+
T dy = static_cast<T>(strip[j]->y) - static_cast<T>(strip[i]->y);
124+
if (dy * dy >= d) break; // y 差已足够大,后面不用再看
125+
// 计算完整平方距离
126+
T cur = DistanceSquareWith(*strip[i], *strip[j]);
127+
if (cur < d) d = cur;
128+
}
76129
}
77-
}
78130

79-
return dist_to_cmp;
131+
return d;
132+
};
133+
134+
return rec(0, n);
80135
}

str/suffix_array.cpp

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ inline void _InducedSort(const std::vector<long> &str,
4242
// 为了方便抄板子, 我这里先把一些变量提取出来.
4343
// 开销应该不大, 因为估计编译器一下就给消掉了.
4444
const long n = str.size(), max_val = buckets.size() - 1;
45-
4645
// 从左向右扫描SA数组,
4746
// 这里的目标是从LMS进行L诱导.
4847
// 然后把L放到每个桶的头部.
@@ -59,7 +58,6 @@ inline void _InducedSort(const std::vector<long> &str,
5958
prev_bucket.left++;
6059
}
6160
}
62-
6361
// 将桶的底部重置, 这里的意思是删除那些LMS.
6462
// 删除掉LMS之后, 再根据之前放入的L的字符,
6563
// 诱导出所有的S的字符的位置.
@@ -71,7 +69,6 @@ inline void _InducedSort(const std::vector<long> &str,
7169
// 也就意味着这个桶满了.
7270
buckets[i].right = prefix_sums[i] - 1;
7371
}
74-
7572
// 从右往左扫描.
7673
// 这次扫描要把S类型的字符放进桶.
7774
for (long i = n - 1; i >= 0; i--) {
@@ -88,7 +85,6 @@ inline void _InducedSort(const std::vector<long> &str,
8885
}
8986
}
9087
}
91-
9288
// str必须是已经被处理好的, 确认了最后的数字是全局唯一最小的哨兵的串.
9389
// max_val可以给的稍微大一点也没关系.
9490
inline std::vector<long> _SAIS(const std::vector<long> &str,
@@ -103,21 +99,17 @@ inline std::vector<long> _SAIS(const std::vector<long> &str,
10399
// str[i] < str[i + 1], 记作S,
104100
// str[i] == str[i + 1], type[i] = type[i + 1].
105101
std::vector<Type> type(str.size(), S_TYPE);
106-
107102
// 这里因为建立桶的需求,
108103
// 所以需要统计每个字符在这里都出现了几次.
109104
// 因为我们两次诱导排序,
110105
// 使用的是同一个串, 所以我们就不再排序的过程中扫描这个了哈.
111106
std::vector<long> cnt_occurance(max_val + 1, 0);
112-
113107
// 尾部字符单独统计,
114108
// 因为下面扫描全字符串是从倒数第二个字符开始的.
115109
cnt_occurance[str.back()]++;
116-
117110
// 收集所有LMS的下标.
118111
std::vector<long> lms_incidies;
119112
lms_incidies.reserve(str.size() / 2);
120-
121113
// 逆序遍历字符串, 获取类型.
122114
// 这里逆序遍历的原因是, 如果
123115
// str[i] == str[i + 1], 那么则有:
@@ -140,11 +132,9 @@ inline std::vector<long> _SAIS(const std::vector<long> &str,
140132
} else {
141133
assert(false);
142134
}
143-
144135
// 记录出现次数.
145136
cnt_occurance[str[i]]++;
146137
}
147-
148138
// 创建前缀和数组, 为建立桶和诱导排序做准备.
149139
std::vector<long> prefix_sums(max_val + 2);
150140
std::partial_sum(cnt_occurance.begin(), cnt_occurance.end(),
@@ -158,7 +148,6 @@ inline std::vector<long> _SAIS(const std::vector<long> &str,
158148
buckets[i].left = prefix_sums[i - 1];
159149
buckets[i].right = prefix_sums[i] - 1;
160150
}
161-
162151
// 放入LMS.
163152
// 这里对于同一个字母, 入桶的顺序应该是倒序的.
164153
// 这个似乎和诱导排序的实现有关系.
@@ -169,19 +158,15 @@ inline std::vector<long> _SAIS(const std::vector<long> &str,
169158
auto curr_idx = *it;
170159
auto curr_char = str[curr_idx];
171160
auto &curr_bucket = buckets[curr_char];
172-
173161
// 将对应的下标放入桶中.
174162
SA[curr_bucket.right] = curr_idx;
175163
curr_bucket.right--;
176164
}
177-
178165
// 进行第一次诱导排序.
179166
_InducedSort(str, type, SA, prefix_sums, buckets);
180-
181167
// 创建名字和下标的对应关系.
182168
// 这里用names数组表达对应位置的名字.
183169
std::vector<long> names(str.size(), -1);
184-
185170
// 这两个变量分别记录了下发的名字的数量,
186171
// 和上一个被探测到的LMS的坐标.
187172
long name_cnt = 0;
@@ -203,7 +188,6 @@ inline std::vector<long> _SAIS(const std::vector<long> &str,
203188
return false;
204189
}
205190
};
206-
207191
// 这里的这个函数是用来比较两个LMS子串是否相等的.
208192
const auto is_lms_eq = [&](const unsigned long idx1,
209193
const unsigned long idx2) -> bool {
@@ -217,7 +201,6 @@ inline std::vector<long> _SAIS(const std::vector<long> &str,
217201
if (idx1 == str.size() - 1 || idx2 == str.size() - 1) {
218202
return false;
219203
}
220-
221204
// 从偏移量为0开始比较
222205
long offset = 0;
223206
do {
@@ -229,18 +212,15 @@ inline std::vector<long> _SAIS(const std::vector<long> &str,
229212
if (type[idx1 + offset] != type[idx2 + offset]) {
230213
return false;
231214
}
232-
233215
// 手动更新偏移量
234216
offset++;
235217
// 循环条件: 两个待比较位置都没有来到下一个LMS
236218
} while (!is_lms(idx1 + offset) && !is_lms(idx2 + offset));
237-
238219
// 如果有一个没有到达下一个LMS但是另外一个到达,
239220
// 那么二者一定不相等.
240221
if (!is_lms(idx1 + offset) || !is_lms(idx2 + offset)) {
241222
return false;
242223
}
243-
244224
// 否则还是比较这两个LMS对应的字符.
245225
if (str[idx1 + offset] != str[idx2 + offset]) {
246226
return false;
@@ -250,7 +230,6 @@ inline std::vector<long> _SAIS(const std::vector<long> &str,
250230
}
251231
return true;
252232
};
253-
254233
// 这个时候, it是当前正在处理的的lms下标
255234
if (is_lms(it)) {
256235
// 如果上一个LMS存在,
@@ -271,13 +250,11 @@ inline std::vector<long> _SAIS(const std::vector<long> &str,
271250
// 那么肯定要分配一个新名字.
272251
name_cnt++;
273252
}
274-
275253
// 将名字下发下去.
276254
names[it] = name_cnt - 1;
277255
last_lms_idx = it;
278256
}
279257
}
280-
281258
// 命名唯一, 无需递归,
282259
// 直接返回.
283260
if (static_cast<unsigned long>(name_cnt) == lms_incidies.size()) {
@@ -296,10 +273,8 @@ inline std::vector<long> _SAIS(const std::vector<long> &str,
296273
auto curr_lms_idx = lms_incidies[i];
297274
lms_str[i] = names[curr_lms_idx];
298275
}
299-
300276
// 最大的一个名字是最大的name_cnt - 1.
301277
lms_SA = _SAIS(lms_str, name_cnt - 1);
302-
303278
// 生成一个新的桶并清空SA数组,
304279
// 进行第二次诱导排序.
305280
std::fill(SA.begin(), SA.end(), -1);
@@ -308,7 +283,6 @@ inline std::vector<long> _SAIS(const std::vector<long> &str,
308283
buckets[i].left = prefix_sums[i - 1];
309284
buckets[i].right = prefix_sums[i] - 1;
310285
}
311-
312286
// 这里倒序遍历, 同时将遍历到的位置放在桶对应字母的右侧.
313287
// 因此可以保证lms_SA中靠右的下标会优先被处理后放入桶的右侧.
314288
// 因此对于同一个字母, lms_SA中靠右的在桶中也靠右.
@@ -321,19 +295,15 @@ inline std::vector<long> _SAIS(const std::vector<long> &str,
321295
// 存放的那个lms的名字对应的下标,
322296
// 就是lms_incidies[lms_SA[i]]中存放的那个下标.
323297
auto curr_lms_idx = lms_incidies[lms_SA[i]];
324-
325298
// 同样地获取桶
326299
auto curr_char = str[curr_lms_idx];
327300
auto &curr_bucket = buckets[curr_char];
328-
329301
// 将下标放在桶的右侧.
330302
SA[curr_bucket.right] = curr_lms_idx;
331303
curr_bucket.right--;
332304
}
333-
334305
// 第二次诱导排序.
335306
_InducedSort(str, type, SA, prefix_sums, buckets);
336-
337307
return SA;
338308
}
339309
}
@@ -350,7 +320,6 @@ std::vector<unsigned long> BuildSuffixArray(const Container &str) {
350320
if (processed[i] > max_val) max_val = processed[i];
351321
}
352322
processed.back() = 0; // 哨兵,唯一且最小
353-
354323
auto res = _SAIS(processed, max_val);
355324
// res[0] 对应哨兵的位置 (通常是 processed.size() - 1)
356325
std::vector<unsigned long> processed_res(std::next(res.begin()), res.end());
@@ -368,7 +337,6 @@ std::vector<unsigned long> suffix_array(Container &&str) {
368337
if (processed[i] > max_val) max_val = processed[i];
369338
}
370339
processed.back() = 0; // 哨兵,唯一且最小
371-
372340
auto res = _SAIS(processed, max_val);
373341
// res[0] 对应哨兵的位置 (通常是 processed.size() - 1)
374342
std::vector<unsigned long> processed_res(std::next(res.begin()), res.end());

0 commit comments

Comments
 (0)