diff --git a/NAMESPACE b/NAMESPACE
index 216cf1ce..0924ac2d 100755
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -590,3 +590,4 @@ importFrom(vctrs,vec_ptype_abbr)
 importFrom(vctrs,vec_ptype_full)
 importFrom(vctrs,vec_restore)
 importFrom(vctrs,vec_slice)
+useDynLib(posterior, .registration = TRUE)
diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R
index 7de50fa9..66102bdd 100644
--- a/R/pareto_smooth.R
+++ b/R/pareto_smooth.R
@@ -499,11 +499,15 @@ ps_tail <- function(x,
     x <- -x
   }
 
-  ndraws <- length(x)
-  tail_ids <- seq(ndraws - ndraws_tail + 1, ndraws)
+  if (is_constant(x)) {
+    if (tail == "left") {
+      x <- -x
+    }
+    return(list(x = x, k = NA))
+  }
 
-  ord <- sort.int(x, index.return = TRUE)
-  draws_tail <- ord$x[tail_ids]
+  tail_info <- .ps_tail_select(x, ndraws_tail)
+  draws_tail <- tail_info$tail
 
   if (is_constant(draws_tail)) {
     if (tail == "left") {
@@ -512,9 +516,9 @@ ps_tail <- function(x,
     return(list(x = x, k = NA))
   }
 
-  cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values
-  if (cutoff == ord$x[min(tail_ids)]) {
-    # cutoff is not smaller than the tail values
+  cutoff <- tail_info$cutoff # largest draw that is not part of the tail
+  if (cutoff == draws_tail[1]) {
+    # cutoff is not smaller than the tail values
     cutoff <- cutoff - .Machine$double.eps
   }
 
@@ -540,7 +544,7 @@ ps_tail <- function(x,
   # truncate at max of raw draws
   if (!is.null(smoothed)) {
     smoothed[smoothed > max_tail] <- max_tail
-    x[ord$ix[tail_ids]] <- smoothed
+    x[tail_info$tail_idx] <- smoothed
   }
 
   if (tail == "left") {
@@ -552,6 +556,21 @@ ps_tail <- function(x,
   return(out)
 }
 
+#' Select the largest draws without fully sorting
+#'
+#' Internal wrapper around the compiled routine `posterior_ps_tail_select`.
+#'
+#' @param x vector of draws. Coerced with `as.double()` because the compiled
+#'   routine reads it with `REAL()`, which is only valid for double vectors.
+#' @param ndraws_tail number of tail draws, coerced with `as.integer()`. The
+#'   compiled routine signals an error unless `1 <= ndraws_tail < length(x)`.
+#' @return A list with elements `cutoff`, `tail` (ascending) and `tail_idx`
+#'   (positions of the tail values in `x`).
+#' @noRd
+.ps_tail_select <- function(x, ndraws_tail) {
+  .Call(posterior_ps_tail_select, as.double(x), as.integer(ndraws_tail))
+}
+
 #' Extra Pareto-k diagnostics
 #'
 #' internal function to calculate the extra diagnostics for a given
diff --git a/R/posterior-package.R b/R/posterior-package.R
index ffd6bdc6..6288c308 100644
--- a/R/posterior-package.R
+++ b/R/posterior-package.R
@@ -65,4 +65,5 @@
 #' match between two objects involved in a binary operation. Whether this
 #' causes a warning can be controlled by this option.
 #'
+#' @useDynLib posterior, .registration = TRUE
 "_PACKAGE"
diff --git a/src/.gitignore b/src/.gitignore
new file mode 100644
index 00000000..22034c46
--- /dev/null
+++ b/src/.gitignore
@@ -0,0 +1,3 @@
+*.o
+*.so
+*.dll
diff --git a/src/init.c b/src/init.c
new file mode 100644
index 00000000..b9e07714
--- /dev/null
+++ b/src/init.c
@@ -0,0 +1,18 @@
+#include <R.h>
+#include <Rinternals.h>
+#include <stdlib.h> // for NULL
+#include <R_ext/Rdynload.h>
+
+/* .Call calls */
+extern SEXP posterior_ps_tail_select(SEXP, SEXP);
+
+static const R_CallMethodDef CallEntries[] = {
+    {"posterior_ps_tail_select", (DL_FUNC) &posterior_ps_tail_select, 2},
+    {NULL, NULL, 0}
+};
+
+void R_init_posterior(DllInfo *dll)
+{
+    R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
+    R_useDynamicSymbols(dll, FALSE);
+}
diff --git a/src/ps_tail_select.c b/src/ps_tail_select.c
new file mode 100644
index 00000000..45876bfa
--- /dev/null
+++ b/src/ps_tail_select.c
@@ -0,0 +1,152 @@
+/* NOTE(review): the header names of these four includes were not visible in
+ * the reviewed text and are reconstructed; confirm against the real file. */
+#include <R.h>
+#include <Rinternals.h>
+#include <stddef.h>
+#include <stdlib.h>
+
+/* A draw together with its 1-based position in the input vector. */
+typedef struct {
+    double value;
+    int index;
+} Node;
+
+/* Order nodes by value, breaking ties by the smaller index. Comparisons with
+ * NaN are false both ways, so a NaN value is ordered by index alone. */
+static inline int node_less(Node a, Node b) {
+    if (a.value < b.value) return 1;
+    if (a.value > b.value) return 0;
+    return a.index < b.index;
+}
+
+/* Sink heap[i] until the min-heap property holds within heap[0..n). */
+static inline void heap_sift_down(Node *heap, int n, int i) {
+    Node x = heap[i];
+
+    for (;;) {
+        int left = 2 * i + 1;
+
+        if (left >= n) {
+            break;
+        }
+
+        int child = left;
+        int right = left + 1;
+
+        if (right < n && node_less(heap[right], heap[left])) {
+            child = right;
+        }
+
+        if (!node_less(heap[child], x)) {
+            break;
+        }
+
+        heap[i] = heap[child];
+        i = child;
+    }
+
+    heap[i] = x;
+}
+
+/* Remove and return the smallest node; *n is decremented. */
+static inline Node heap_pop_min(Node *heap, int *n) {
+    Node out = heap[0];
+    int last = --(*n);
+
+    if (last > 0) {
+        heap[0] = heap[last];
+        heap_sift_down(heap, last, 0);
+    }
+
+    return out;
+}
+
+/*
+ * .Call entry point: select the ndraws_tail largest elements of x.
+ *
+ * A min-heap of size ndraws_tail + 1 is kept while scanning x, so that it
+ * ends up holding the ndraws_tail + 1 largest (value, index) pairs.
+ *
+ * Returns a named list:
+ *   cutoff   - the smallest of the ndraws_tail + 1 retained values,
+ *   tail     - the ndraws_tail largest values in ascending order,
+ *   tail_idx - 1-based positions in x of the values in tail.
+ *
+ * Raises an R error if x is not a double vector, if ndraws_tail is not a
+ * single integer, or unless 1 <= ndraws_tail < length(x).
+ */
+SEXP posterior_ps_tail_select(SEXP x, SEXP ndraws_tail) {
+    /* REAL() and INTEGER() must only be used on vectors of the matching
+     * type, and element 0 is only readable when the vector is non-empty. */
+    if (!isReal(x)) {
+        error("`x` must be a double vector.");
+    }
+    if (!isInteger(ndraws_tail) || LENGTH(ndraws_tail) != 1) {
+        error("`ndraws_tail` must be a single integer.");
+    }
+
+    const double *xx = REAL(x);
+    int n = LENGTH(x);
+    int m = INTEGER(ndraws_tail)[0];
+
+    if (m < 1 || m >= n) {
+        error("Invalid ndraws_tail.");
+    }
+
+    int keep = m + 1;
+    /* R_alloc memory is reclaimed by R when the .Call returns. */
+    Node *heap = (Node *) R_alloc((size_t) keep, sizeof *heap);
+
+    /* Seed the heap with the first keep elements, then heapify. */
+    for (int i = 0; i < keep; ++i) {
+        heap[i].value = xx[i];
+        heap[i].index = i + 1;
+    }
+
+    for (int i = keep / 2; i > 0; --i) {
+        heap_sift_down(heap, keep, i - 1);
+    }
+
+    /* Replace the current minimum whenever a larger element is seen. */
+    for (int i = keep; i < n; ++i) {
+        Node candidate;
+        candidate.value = xx[i];
+        candidate.index = i + 1;
+
+        if (node_less(heap[0], candidate)) {
+            heap[0] = candidate;
+            heap_sift_down(heap, keep, 0);
+        }
+    }
+
+    SEXP out = PROTECT(allocVector(VECSXP, 3));
+    SEXP names = PROTECT(allocVector(STRSXP, 3));
+    SEXP cutoff = PROTECT(allocVector(REALSXP, 1));
+    SEXP tail = PROTECT(allocVector(REALSXP, m));
+    SEXP tail_idx = PROTECT(allocVector(INTSXP, m));
+
+    int heap_n = keep;
+
+    /* The first pop is the smallest retained value, i.e. the cutoff. */
+    Node node = heap_pop_min(heap, &heap_n);
+    REAL(cutoff)[0] = node.value;
+
+    /* The remaining pops come out in ascending order. */
+    for (int i = 0; i < m; ++i) {
+        node = heap_pop_min(heap, &heap_n);
+        REAL(tail)[i] = node.value;
+        INTEGER(tail_idx)[i] = node.index;
+    }
+
+    SET_VECTOR_ELT(out, 0, cutoff);
+    SET_VECTOR_ELT(out, 1, tail);
+    SET_VECTOR_ELT(out, 2, tail_idx);
+
+    SET_STRING_ELT(names, 0, mkChar("cutoff"));
+    SET_STRING_ELT(names, 1, mkChar("tail"));
+    SET_STRING_ELT(names, 2, mkChar("tail_idx"));
+    setAttrib(out, R_NamesSymbol, names);
+
+    UNPROTECT(5);
+    return out;
+}