Skip to content

Commit 8dba827

Browse files
fix: fixed boruvka mst algorithm (#132)
1 parent 8255c6e commit 8dba827

File tree

2 files changed

+105
-23
lines changed

2 files changed

+105
-23
lines changed

cpl/inc/tree.h

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -426,39 +426,40 @@ OutIt dijkstra(const std::size_t n, InIt first, InIt last, OutIt dest, const std
426426
return dest;
427427
}
428428

429-
template <class InIt, class OutIt, class Pr1 = EdgeLess, class Pr2 = EdgeLess>
430-
OutIt boruvka(const std::size_t n, InIt first, InIt last, OutIt dest, Pr1 preferred = Pr1{}, Pr2 tie_break = Pr2{}) {
429+
template <class InIt, class OutIt, class Pred = EdgeLess>
430+
OutIt boruvka(const std::size_t n, InIt first, InIt last, OutIt dest, Pred pred = Pred{}) {
431+
using edge_ptr = decltype(&*first);
432+
431433
DisjointSet<std::size_t> ds(n);
432-
std::vector<std::size_t> cheapest(n, std::numeric_limits<std::size_t>::max());
433-
std::vector<std::size_t> cheapest_edge(n, std::numeric_limits<std::size_t>::max());
434+
std::vector<edge_ptr> cheapest(n);
434435

435-
std::size_t mst_size = 0;
436-
while (mst_size < n - 1) {
437-
for (std::size_t i = 0; i < n; ++i) {
438-
cheapest[i] = std::numeric_limits<std::size_t>::max();
439-
cheapest_edge[i] = std::numeric_limits<std::size_t>::max();
440-
}
436+
for (std::size_t components = n; components > 1;) {
437+
std::fill(cheapest.begin(), cheapest.end(), nullptr);
441438

442439
for (auto it = first; it != last; ++it) {
443440
auto set1 = ds.find(it->from);
444441
auto set2 = ds.find(it->to);
445-
446-
if (set1 == set2) {
447-
continue;
448-
}
449-
450-
if (preferred(*it, cheapest[set1]) || (tie_break(*it, cheapest[set1]) && it->weight == cheapest[set1])) {
451-
cheapest[set1] = it->weight;
452-
cheapest_edge[set1] = it->to;
442+
if (set1 != set2) {
443+
if (!cheapest[set1] || pred(*it, *cheapest[set1])) {
444+
cheapest[set1] = &*it;
445+
}
446+
447+
if (!cheapest[set2] || pred(*it, *cheapest[set2])) {
448+
cheapest[set2] = &*it;
449+
}
453450
}
454451
}
455452

456453
for (std::size_t i = 0; i < n; ++i) {
457-
if (cheapest[i] != std::numeric_limits<std::size_t>::max()) {
458-
ds.union_rank(cheapest_edge[i], i);
459-
*dest = Edge{cheapest_edge[i], i, cheapest[i]};
460-
++mst_size;
461-
++dest;
454+
if (!cheapest[i]) {
455+
continue;
456+
}
457+
458+
auto set1 = ds.find(cheapest[i]->from), set2 = ds.find(cheapest[i]->to);
459+
if (set1 != set2) {
460+
ds.union_rank(set1, set2);
461+
*dest++ = *cheapest[i];
462+
--components;
462463
}
463464
}
464465
}

tests/cpl/boruvka_mst/test.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// Copyright (c) Brandon Pacewic
2+
// SPDX-License-Identifier: MIT
3+
4+
#include <cassert>
5+
#include <vector>
6+
7+
#include "minimum_spanning_tree_test_cases.hpp"
8+
#include "tree.h"
9+
10+
template <class EdgeContainer>
11+
auto total_weight(const EdgeContainer& edges) {
12+
using weight_type = decltype(edges[0].weight);
13+
weight_type sum = 0;
14+
for (const auto& e : edges) {
15+
sum += e.weight;
16+
}
17+
return sum;
18+
}
19+
20+
int main() {
21+
using namespace std;
22+
using namespace cpl;
23+
24+
{
25+
auto [input, expected] = small_test_case();
26+
vector<Edge<>> mst;
27+
boruvka(9, input.begin(), input.end(), back_inserter(mst));
28+
29+
assert(mst.size() == 8);
30+
31+
auto expected_weight = total_weight(expected);
32+
auto actual_weight = total_weight(mst);
33+
assert(actual_weight == expected_weight);
34+
}
35+
{
36+
auto [input, expected] = single_edge_test_case();
37+
vector<Edge<>> mst;
38+
boruvka(2, input.begin(), input.end(), back_inserter(mst));
39+
40+
assert(mst.size() == 1);
41+
assert(total_weight(mst) == total_weight(expected));
42+
}
43+
{
44+
vector<Edge<>> input = {
45+
{0, 1, 1},
46+
{1, 2, 2},
47+
{0, 2, 3},
48+
};
49+
vector<Edge<>> mst;
50+
boruvka(3, input.begin(), input.end(), back_inserter(mst));
51+
52+
assert(mst.size() == 2);
53+
assert(total_weight(mst) == 3);
54+
}
55+
{
56+
auto [input, expected] = same_weight_test_case();
57+
vector<Edge<>> mst;
58+
boruvka(4, input.begin(), input.end(), back_inserter(mst));
59+
60+
assert(mst.size() == 3);
61+
assert(total_weight(mst) == total_weight(expected));
62+
}
63+
{
64+
auto [input, expected] = large_test_case();
65+
vector<Edge<>> mst;
66+
boruvka(100, input.begin(), input.end(), back_inserter(mst));
67+
68+
assert(mst.size() == 99);
69+
assert(total_weight(mst) == total_weight(expected));
70+
}
71+
{
72+
auto [input, expected] = large_sparse_test_case();
73+
vector<Edge<>> mst;
74+
boruvka(1000, input.begin(), input.end(), back_inserter(mst));
75+
76+
assert(mst.size() == 999);
77+
assert(total_weight(mst) == total_weight(expected));
78+
}
79+
80+
return 0;
81+
}

0 commit comments

Comments
 (0)