@@ -8,31 +8,40 @@ using namespace Rcpp;
88#include < TreeTools/root_tree.h> /* for root_on_node() */
99#include < TreeTools/ClusterTable.h> /* for ClusterTable() */
1010using TreeTools::ClusterTable;
11- using TreeTools::ct_max_leaves;
11+ using TreeTools::ct_stack_size;
12+ using TreeTools::ct_stack_threshold;
13+
14+ using TreeTools::ct_max_leaves; /* TODO remove */
1215
13- #include < array> /* for array */
14- #include < bitset> /* for bitset */
15- #include < vector> /* for vector */
1616#include < cmath> /* for log2(), ceil() */
17- #include < memory> /* for unique_ptr, make_unique */
1817
1918struct StackEntry { int32 L, R, N, W; };
2019
2120// COMCLUSTER computes a strict consensus tree in O(knn).
2221// COMCLUST requires O(kn).
2322// trees is a list of objects of class phylo.
2423// [[Rcpp::export]]
25- int COMCLUST (List trees) {
24+ int COMCLUST (const List& trees) {
2625
2726 int32 v = 0 ;
2827 int32 w = 0 ;
2928 int32 L, R, N, W;
30- int32 L_i, R_i, N_i, W_i;
31-
29+
3230 ClusterTable X (List (trees (0 )));
33- std::array<int32, TreeTools::ct_max_leaves> S;
34-
35- for (int32 i = 1 ; i != trees.length (); i++) {
31+ const int32 n_tip = X.N ();
32+
33+ StackEntry* S_ptr;
34+ std::array<StackEntry, TreeTools::ct_stack_threshold> S_stack;
35+ std::vector<StackEntry> S_heap;
36+
37+ if (n_tip <= TreeTools::ct_stack_threshold) {
38+ S_ptr = S_stack.data ();
39+ } else {
40+ S_heap.resize (n_tip);
41+ S_ptr = S_heap.data ();
42+ }
43+
44+ for (int32 i = 1 ; i < trees.length (); ++i) {
3645 int32 Spos = 0 ; // Empty the stack S
3746
3847 X.CLEAR ();
@@ -42,20 +51,27 @@ int COMCLUST(List trees) {
4251
4352 do {
4453 if (Ti.is_leaf (v)) {
45- CT_PUSH ( X.ENCODE (v), X.ENCODE (v), 1 , 1 ) ;
54+ S_ptr[Spos++] = { X.ENCODE (v), X.ENCODE (v), 1 , 1 } ;
4655 } else {
47- CT_POP (L, R, N, W_i);
56+ const StackEntry& top = S_ptr[--Spos];
57+ L = top.L ;
58+ R = top.R ;
59+ N = top.N ;
60+ const int32 W_i = top.W ;
61+
4862 W = 1 + W_i;
4963 w = w - W_i;
5064 while (w) {
51- CT_POP (L_i, R_i, N_i, W_i) ;
52- if (L_i < L) L = L_i ;
53- if (R_i > R) R = R_i ;
54- N += N_i ;
55- W += W_i ;
56- w -= W_i ;
65+ const StackEntry& next = S_ptr[--Spos] ;
66+ if (next. L < L) L = next. L ;
67+ if (next. R > R) R = next. R ;
68+ N += next. N ;
69+ W += next. W ;
70+ w -= next. W ;
5771 };
58- CT_PUSH (L, R, N, W);
72+
73+ S_ptr[Spos++] = {L, R, N, W};
74+
5975 if (N == R - L + 1 ) { // L..R is contiguous, and must be tested
6076 X.SETSW (L, R);
6177 }
@@ -69,20 +85,12 @@ int COMCLUST(List trees) {
6985}
7086
7187#define IS_LEAF (a ) (a) <= n_tip
72-
73- // COMCLUSTER computes a strict consensus tree in O(knn).
74- // COMCLUST requires O(kn).
88+
7589// trees is a list of objects of class phylo, all with the same tip labels
7690// (try RenumberTips(trees, trees[[1]]))
77- // [[Rcpp::export]]
78- double consensus_info (const List trees, const LogicalVector phylo,
79- const NumericVector p) {
80- if (p[0 ] > 1 + 1e-15 ) { // epsilon catches floating point error
81- Rcpp::stop (" p must be <= 1.0 in consensus_info()" );
82- } else if (p[0 ] < 0.5 ) {
83- Rcpp::stop (" p must be >= 0.5 in consensus_info()" );
84- }
85-
91+ template <typename StackContainer>
92+ double calc_consensus_info (const List &trees, const LogicalVector &phylo,
93+ const NumericVector& p, StackContainer& S) {
8694 int32 v = 0 ;
8795 int32 w = 0 ;
8896 int32 L;
@@ -93,12 +101,14 @@ double consensus_info(const List trees, const LogicalVector phylo,
93101 int32 R_j;
94102 int32 N_j;
95103 int32 W_j;
104+
96105 const int32 n_trees = trees.length ();
97-
106+
98107 std::vector<ClusterTable> tables;
99108 if (std::size_t (n_trees) > tables.max_size ()) {
100109 Rcpp::stop (" Not enough memory available to compute consensus of so many trees" ); // LCOV_EXCL_LINE
101110 }
111+
102112 tables.reserve (n_trees);
103113 for (int32 i = n_trees; i--; ) {
104114 tables.emplace_back (ClusterTable (List (trees (i))));
@@ -109,12 +119,17 @@ double consensus_info(const List trees, const LogicalVector phylo,
109119 (n_trees / 2 ) + 1 : // Splits must occur in MORE THAN 0.5 to be in majority.
110120 std::ceil (p[0 ] * n_trees);
111121 const int32 must_occur_before = 1 + n_trees - thresh;
112-
122+
123+ std::array<int32, TreeTools::ct_stack_threshold> split_count_stack;
124+ std::vector<int32> split_count_heap;
125+ int32* split_count;
126+ if (n_tip < TreeTools::ct_stack_threshold) {
127+ split_count = split_count_stack.data ();
128+ } else {
129+ split_count_heap.resize (n_tip);
130+ split_count = split_count_heap.data ();
131+ }
113132 const bool phylo_info = phylo[0 ];
114-
115- std::array<int32, TreeTools::ct_stack_size * TreeTools::ct_max_leaves> S;
116- std::array<int32, TreeTools::ct_max_leaves> split_count;
117-
118133 double info = 0 ;
119134
120135 const std::size_t ntip_3 = n_tip - 3 ;
@@ -125,7 +140,7 @@ double consensus_info(const List trees, const LogicalVector phylo,
125140 }
126141
127142 std::vector<int32> split_size (n_tip);
128- std::fill (split_count. begin () , split_count. begin () + n_tip, 1 );
143+ std::fill (split_count, split_count + n_tip, 1 );
129144
130145 for (int32 j = i + 1 ; j < n_trees; ++j) {
131146
@@ -207,7 +222,7 @@ double consensus_info(const List trees, const LogicalVector phylo,
207222}
208223
209224// [[Rcpp::export]]
210- IntegerVector robinson_foulds_all_pairs (List tables) {
225+ IntegerVector robinson_foulds_all_pairs (const List& tables) {
211226 const int n_trees = static_cast <int >(tables.size ());
212227 if (n_trees < 2 ) return IntegerVector (0 );
213228
@@ -304,3 +319,35 @@ IntegerVector robinson_foulds_all_pairs(List tables) {
304319
305320 return shared;
306321}
322+
323+ // [[Rcpp::export]]
324+ double consensus_info (const List trees, const LogicalVector phylo,
325+ const NumericVector p) {
326+ if (p[0 ] > 1 + 1e-15 ) { // epsilon catches floating point error
327+ Rcpp::stop (" p must be <= 1.0 in consensus_info()" );
328+ } else if (p[0 ] < 0.5 ) {
329+ Rcpp::stop (" p must be >= 0.5 in consensus_info()" );
330+ }
331+
332+ // First, peek at the tree size to determine allocation strategy
333+ // We'll create a temporary ClusterTable just to check the size
334+ try {
335+ TreeTools::ClusterTable temp_table (Rcpp::List (trees (0 )));
336+ const int32 n_tip = temp_table.N ();
337+
338+ if (n_tip <= ct_stack_threshold) {
339+ // Small tree: use stack-allocated array
340+ std::array<int32, ct_stack_size * ct_stack_threshold> S;
341+ return calc_consensus_info (trees, phylo, p, S);
342+ } else {
343+ // Large tree: use heap-allocated vector
344+ std::vector<int32> S (ct_stack_size * n_tip);
345+ return calc_consensus_info (trees, phylo, p, S);
346+ }
347+ } catch (const std::exception& e) {
348+ Rcpp::stop (e.what ());
349+ }
350+
351+ ASSERT (false && " Unreachable code in consensus_tree" );
352+ return 0.0 ;
353+ }
0 commit comments