Skip to content

Commit 1c321e9

Browse files
committed
Support larger trees
1 parent 3799062 commit 1c321e9

4 files changed

Lines changed: 111 additions & 60 deletions

File tree

R/RcppExports.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@ COMCLUST <- function(trees) {
55
.Call(`_TreeDist_COMCLUST`, trees)
66
}
77

8-
consensus_info <- function(trees, phylo, p) {
9-
.Call(`_TreeDist_consensus_info`, trees, phylo, p)
10-
}
11-
128
robinson_foulds_all_pairs <- function(tables) {
139
.Call(`_TreeDist_robinson_foulds_all_pairs`, tables)
1410
}
1511

12+
consensus_info <- function(trees, phylo, p) {
13+
.Call(`_TreeDist_consensus_info`, trees, phylo, p)
14+
}
15+
1616
HMI_xptr <- function(ptr1, ptr2) {
1717
.Call(`_TreeDist_HMI_xptr`, ptr1, ptr2)
1818
}

src/RcppExports.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,27 @@ Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
1111
#endif
1212

1313
// COMCLUST
14-
int COMCLUST(List trees);
14+
int COMCLUST(const List& trees);
1515
RcppExport SEXP _TreeDist_COMCLUST(SEXP treesSEXP) {
1616
BEGIN_RCPP
1717
Rcpp::RObject rcpp_result_gen;
1818
Rcpp::RNGScope rcpp_rngScope_gen;
19-
Rcpp::traits::input_parameter< List >::type trees(treesSEXP);
19+
Rcpp::traits::input_parameter< const List& >::type trees(treesSEXP);
2020
rcpp_result_gen = Rcpp::wrap(COMCLUST(trees));
2121
return rcpp_result_gen;
2222
END_RCPP
2323
}
24+
// robinson_foulds_all_pairs
25+
IntegerVector robinson_foulds_all_pairs(const List& tables);
26+
RcppExport SEXP _TreeDist_robinson_foulds_all_pairs(SEXP tablesSEXP) {
27+
BEGIN_RCPP
28+
Rcpp::RObject rcpp_result_gen;
29+
Rcpp::RNGScope rcpp_rngScope_gen;
30+
Rcpp::traits::input_parameter< const List& >::type tables(tablesSEXP);
31+
rcpp_result_gen = Rcpp::wrap(robinson_foulds_all_pairs(tables));
32+
return rcpp_result_gen;
33+
END_RCPP
34+
}
2435
// consensus_info
2536
double consensus_info(const List trees, const LogicalVector phylo, const NumericVector p);
2637
RcppExport SEXP _TreeDist_consensus_info(SEXP treesSEXP, SEXP phyloSEXP, SEXP pSEXP) {
@@ -34,17 +45,6 @@ BEGIN_RCPP
3445
return rcpp_result_gen;
3546
END_RCPP
3647
}
37-
// robinson_foulds_all_pairs
38-
IntegerVector robinson_foulds_all_pairs(List tables);
39-
RcppExport SEXP _TreeDist_robinson_foulds_all_pairs(SEXP tablesSEXP) {
40-
BEGIN_RCPP
41-
Rcpp::RObject rcpp_result_gen;
42-
Rcpp::RNGScope rcpp_rngScope_gen;
43-
Rcpp::traits::input_parameter< List >::type tables(tablesSEXP);
44-
rcpp_result_gen = Rcpp::wrap(robinson_foulds_all_pairs(tables));
45-
return rcpp_result_gen;
46-
END_RCPP
47-
}
4848
// HMI_xptr
4949
double HMI_xptr(SEXP ptr1, SEXP ptr2);
5050
RcppExport SEXP _TreeDist_HMI_xptr(SEXP ptr1SEXP, SEXP ptr2SEXP) {
@@ -380,8 +380,8 @@ END_RCPP
380380

381381
static const R_CallMethodDef CallEntries[] = {
382382
{"_TreeDist_COMCLUST", (DL_FUNC) &_TreeDist_COMCLUST, 1},
383-
{"_TreeDist_consensus_info", (DL_FUNC) &_TreeDist_consensus_info, 3},
384383
{"_TreeDist_robinson_foulds_all_pairs", (DL_FUNC) &_TreeDist_robinson_foulds_all_pairs, 1},
384+
{"_TreeDist_consensus_info", (DL_FUNC) &_TreeDist_consensus_info, 3},
385385
{"_TreeDist_HMI_xptr", (DL_FUNC) &_TreeDist_HMI_xptr, 2},
386386
{"_TreeDist_HH_xptr", (DL_FUNC) &_TreeDist_HH_xptr, 1},
387387
{"_TreeDist_EHMI_xptr", (DL_FUNC) &_TreeDist_EHMI_xptr, 4},

src/day_1985.cpp

Lines changed: 87 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,40 @@ using namespace Rcpp;
88
#include <TreeTools/root_tree.h> /* for root_on_node() */
99
#include <TreeTools/ClusterTable.h> /* for ClusterTable() */
1010
using TreeTools::ClusterTable;
11-
using TreeTools::ct_max_leaves;
11+
using TreeTools::ct_stack_size;
12+
using TreeTools::ct_stack_threshold;
13+
14+
using TreeTools::ct_max_leaves; /* TODO remove */
1215

13-
#include <array> /* for array */
14-
#include <bitset> /* for bitset */
15-
#include <vector> /* for vector */
1616
#include <cmath> /* for log2(), ceil() */
17-
#include <memory> /* for unique_ptr, make_unique */
1817

1918
struct StackEntry { int32 L, R, N, W; };
2019

2120
// COMCLUSTER computes a strict consensus tree in O(k n^2).
2221
// COMCLUST requires O(kn).
2322
// trees is a list of objects of class phylo.
2423
// [[Rcpp::export]]
25-
int COMCLUST(List trees) {
24+
int COMCLUST(const List& trees) {
2625

2726
int32 v = 0;
2827
int32 w = 0;
2928
int32 L, R, N, W;
30-
int32 L_i, R_i, N_i, W_i;
31-
29+
3230
ClusterTable X(List(trees(0)));
33-
std::array<int32, TreeTools::ct_max_leaves> S;
34-
35-
for (int32 i = 1; i != trees.length(); i++) {
31+
const int32 n_tip = X.N();
32+
33+
StackEntry* S_ptr;
34+
std::array<StackEntry, TreeTools::ct_stack_threshold> S_stack;
35+
std::vector<StackEntry> S_heap;
36+
37+
if (n_tip <= TreeTools::ct_stack_threshold) {
38+
S_ptr = S_stack.data();
39+
} else {
40+
S_heap.resize(n_tip);
41+
S_ptr = S_heap.data();
42+
}
43+
44+
for (int32 i = 1; i < trees.length(); ++i) {
3645
int32 Spos = 0; // Empty the stack S
3746

3847
X.CLEAR();
@@ -42,20 +51,27 @@ int COMCLUST(List trees) {
4251

4352
do {
4453
if (Ti.is_leaf(v)) {
45-
CT_PUSH(X.ENCODE(v), X.ENCODE(v), 1, 1);
54+
S_ptr[Spos++] = {X.ENCODE(v), X.ENCODE(v), 1, 1};
4655
} else {
47-
CT_POP(L, R, N, W_i);
56+
const StackEntry& top = S_ptr[--Spos];
57+
L = top.L;
58+
R = top.R;
59+
N = top.N;
60+
const int32 W_i = top.W;
61+
4862
W = 1 + W_i;
4963
w = w - W_i;
5064
while (w) {
51-
CT_POP(L_i, R_i, N_i, W_i);
52-
if (L_i < L) L = L_i;
53-
if (R_i > R) R = R_i;
54-
N += N_i;
55-
W += W_i;
56-
w -= W_i;
65+
const StackEntry& next = S_ptr[--Spos];
66+
if (next.L < L) L = next.L;
67+
if (next.R > R) R = next.R;
68+
N += next.N;
69+
W += next.W;
70+
w -= next.W;
5771
};
58-
CT_PUSH(L, R, N, W);
72+
73+
S_ptr[Spos++] = {L, R, N, W};
74+
5975
if (N == R - L + 1) { // L..R is contiguous, and must be tested
6076
X.SETSW(L, R);
6177
}
@@ -69,20 +85,12 @@ int COMCLUST(List trees) {
6985
}
7086

7187
#define IS_LEAF(a) (a) <= n_tip
72-
73-
// COMCLUSTER computes a strict consensus tree in O(knn).
74-
// COMCLUST requires O(kn).
88+
7589
// trees is a list of objects of class phylo, all with the same tip labels
7690
// (try RenumberTips(trees, trees[[1]]))
77-
// [[Rcpp::export]]
78-
double consensus_info(const List trees, const LogicalVector phylo,
79-
const NumericVector p) {
80-
if (p[0] > 1 + 1e-15) { // epsilon catches floating point error
81-
Rcpp::stop("p must be <= 1.0 in consensus_info()");
82-
} else if (p[0] < 0.5) {
83-
Rcpp::stop("p must be >= 0.5 in consensus_info()");
84-
}
85-
91+
template <typename StackContainer>
92+
double calc_consensus_info(const List &trees, const LogicalVector &phylo,
93+
const NumericVector& p, StackContainer& S) {
8694
int32 v = 0;
8795
int32 w = 0;
8896
int32 L;
@@ -93,12 +101,14 @@ double consensus_info(const List trees, const LogicalVector phylo,
93101
int32 R_j;
94102
int32 N_j;
95103
int32 W_j;
104+
96105
const int32 n_trees = trees.length();
97-
106+
98107
std::vector<ClusterTable> tables;
99108
if (std::size_t(n_trees) > tables.max_size()) {
100109
Rcpp::stop("Not enough memory available to compute consensus of so many trees"); // LCOV_EXCL_LINE
101110
}
111+
102112
tables.reserve(n_trees);
103113
for (int32 i = n_trees; i--; ) {
104114
tables.emplace_back(ClusterTable(List(trees(i))));
@@ -109,12 +119,17 @@ double consensus_info(const List trees, const LogicalVector phylo,
109119
(n_trees / 2) + 1 : // Splits must occur in MORE THAN 0.5 to be in majority.
110120
std::ceil(p[0] * n_trees);
111121
const int32 must_occur_before = 1 + n_trees - thresh;
112-
122+
123+
std::array<int32, TreeTools::ct_stack_threshold> split_count_stack;
124+
std::vector<int32> split_count_heap;
125+
int32* split_count;
126+
if (n_tip < TreeTools::ct_stack_threshold) {
127+
split_count = split_count_stack.data();
128+
} else {
129+
split_count_heap.resize(n_tip);
130+
split_count = split_count_heap.data();
131+
}
113132
const bool phylo_info = phylo[0];
114-
115-
std::array<int32, TreeTools::ct_stack_size * TreeTools::ct_max_leaves> S;
116-
std::array<int32, TreeTools::ct_max_leaves> split_count;
117-
118133
double info = 0;
119134

120135
const std::size_t ntip_3 = n_tip - 3;
@@ -125,7 +140,7 @@ double consensus_info(const List trees, const LogicalVector phylo,
125140
}
126141

127142
std::vector<int32> split_size(n_tip);
128-
std::fill(split_count.begin(), split_count.begin() + n_tip, 1);
143+
std::fill(split_count, split_count + n_tip, 1);
129144

130145
for (int32 j = i + 1; j < n_trees; ++j) {
131146

@@ -207,7 +222,7 @@ double consensus_info(const List trees, const LogicalVector phylo,
207222
}
208223

209224
// [[Rcpp::export]]
210-
IntegerVector robinson_foulds_all_pairs(List tables) {
225+
IntegerVector robinson_foulds_all_pairs(const List& tables) {
211226
const int n_trees = static_cast<int>(tables.size());
212227
if (n_trees < 2) return IntegerVector(0);
213228

@@ -304,3 +319,35 @@ IntegerVector robinson_foulds_all_pairs(List tables) {
304319

305320
return shared;
306321
}
322+
323+
// [[Rcpp::export]]
324+
double consensus_info(const List trees, const LogicalVector phylo,
325+
const NumericVector p) {
326+
if (p[0] > 1 + 1e-15) { // epsilon catches floating point error
327+
Rcpp::stop("p must be <= 1.0 in consensus_info()");
328+
} else if (p[0] < 0.5) {
329+
Rcpp::stop("p must be >= 0.5 in consensus_info()");
330+
}
331+
332+
// First, peek at the tree size to determine allocation strategy
333+
// We'll create a temporary ClusterTable just to check the size
334+
try {
335+
TreeTools::ClusterTable temp_table(Rcpp::List(trees(0)));
336+
const int32 n_tip = temp_table.N();
337+
338+
if (n_tip <= ct_stack_threshold) {
339+
// Small tree: use stack-allocated array
340+
std::array<int32, ct_stack_size * ct_stack_threshold> S;
341+
return calc_consensus_info(trees, phylo, p, S);
342+
} else {
343+
// Large tree: use heap-allocated vector
344+
std::vector<int32> S(ct_stack_size * n_tip);
345+
return calc_consensus_info(trees, phylo, p, S);
346+
}
347+
} catch(const std::exception& e) {
348+
Rcpp::stop(e.what());
349+
}
350+
351+
ASSERT(false && "Unreachable code in consensus_info");
352+
return 0.0;
353+
}

tests/testthat/test-day_1985.cpp.r

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,13 @@ test_that("Day 1985 examples", {
4646
t2 <- PrepareTree("(((2, 4, 5, 7, 9, 10, 12, 13), 1, 14), (6, (8, 11)), 3);")
4747
as.ClusterTable(t2)
4848

49-
expect_equal(2L, COMCLUST(list(t1, t2)))
50-
expect_equal(7L, as.integer(RobinsonFoulds(list(t1, t2))))
49+
expect_equal(COMCLUST(list(t1, t2)), 2L)
50+
expect_equal(as.integer(RobinsonFoulds(list(t1, t2))), 7L)
5151

52+
# Large trees
53+
twoBiggies <- list(BalancedTree(20000), BalancedTree(20000))
54+
expect_no_error(COMCLUST(twoBiggies))
55+
expect_equal(RobinsonFoulds(twoBiggies)[[1]], 0)
5256
})
5357

5458
test_that("RobinsonFoulds() with realistic trees", {

0 commit comments

Comments
 (0)