Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -590,3 +590,4 @@ importFrom(vctrs,vec_ptype_abbr)
importFrom(vctrs,vec_ptype_full)
importFrom(vctrs,vec_restore)
importFrom(vctrs,vec_slice)
useDynLib(posterior, .registration = TRUE)
23 changes: 15 additions & 8 deletions R/pareto_smooth.R
Original file line number Diff line number Diff line change
Expand Up @@ -499,11 +499,15 @@ ps_tail <- function(x,
x <- -x
}

ndraws <- length(x)
tail_ids <- seq(ndraws - ndraws_tail + 1, ndraws)
if (is_constant(x)) {
if (tail == "left") {
x <- -x
}
return(list(x = x, k = NA))
}

ord <- sort.int(x, index.return = TRUE)
draws_tail <- ord$x[tail_ids]
tail_info <- .ps_tail_select(x, ndraws_tail)
draws_tail <- tail_info$tail

if (is_constant(draws_tail)) {
if (tail == "left") {
Expand All @@ -512,9 +516,8 @@ ps_tail <- function(x,
return(list(x = x, k = NA))
}

cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values
if (cutoff == ord$x[min(tail_ids)]) {
# cutoff is not smaller than the tail values
cutoff <- tail_info$cutoff
if (cutoff == draws_tail[1]) {
cutoff <- cutoff - .Machine$double.eps
}

Expand All @@ -540,7 +543,7 @@ ps_tail <- function(x,
# truncate at max of raw draws
if (!is.null(smoothed)) {
smoothed[smoothed > max_tail] <- max_tail
x[ord$ix[tail_ids]] <- smoothed
x[tail_info$tail_idx] <- smoothed
}

if (tail == "left") {
Expand All @@ -552,6 +555,10 @@ ps_tail <- function(x,
return(out)
}

.ps_tail_select <- function(x, ndraws_tail) {
.Call(posterior_ps_tail_select, x, as.integer(ndraws_tail))
}

#' Extra Pareto-k diagnostics
#'
#' internal function to calculate the extra diagnostics for a given
Expand Down
1 change: 1 addition & 0 deletions R/posterior-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,5 @@
#' match between two objects involved in a binary operation. Whether this
#' causes a warning can be controlled by this option.
#'
#' @useDynLib posterior, .registration = TRUE
"_PACKAGE"
3 changes: 3 additions & 0 deletions src/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*.o
*.so
*.dll
18 changes: 18 additions & 0 deletions src/init.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include <R.h>
#include <Rinternals.h>
#include <stdlib.h> // for NULL
#include <R_ext/Rdynload.h>

/* .Call calls */
extern SEXP posterior_ps_tail_select(SEXP, SEXP);

static const R_CallMethodDef CallEntries[] = {
{"posterior_ps_tail_select", (DL_FUNC) &posterior_ps_tail_select, 2},
{NULL, NULL, 0}
};

void R_init_posterior(DllInfo *dll)
{
R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
R_useDynamicSymbols(dll, FALSE);
}
117 changes: 117 additions & 0 deletions src/ps_tail_select.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#include <R.h>
#include <Rinternals.h>
#include <R_ext/Rdynload.h>
#include <stdlib.h>

typedef struct {
double value;
int index;
} Node;

static inline int node_less(Node a, Node b) {
if (a.value < b.value) return 1;
if (a.value > b.value) return 0;
return a.index < b.index;
}

static inline void heap_sift_down(Node *heap, int n, int i) {
Node x = heap[i];

for (;;) {
int left = 2 * i + 1;

if (left >= n) {
break;
}

int child = left;
int right = left + 1;

if (right < n && node_less(heap[right], heap[left])) {
child = right;
}

if (!node_less(heap[child], x)) {
break;
}

heap[i] = heap[child];
i = child;
}

heap[i] = x;
}

static inline Node heap_pop_min(Node *heap, int *n) {
Node out = heap[0];
int last = --(*n);

if (last > 0) {
heap[0] = heap[last];
heap_sift_down(heap, last, 0);
}

return out;
}

SEXP posterior_ps_tail_select(SEXP x, SEXP ndraws_tail) {
const double *xx = REAL(x);
int n = LENGTH(x);
int m = INTEGER(ndraws_tail)[0];

if (m < 1 || m >= n) {
error("Invalid ndraws_tail.");
}

int keep = m + 1;
Node *heap = (Node *) R_alloc((size_t) keep, sizeof *heap);

for (int i = 0; i < keep; ++i) {
heap[i].value = xx[i];
heap[i].index = i + 1;
}

for (int i = keep / 2; i > 0; --i) {
heap_sift_down(heap, keep, i - 1);
}

for (int i = keep; i < n; ++i) {
Node candidate;
candidate.value = xx[i];
candidate.index = i + 1;

if (node_less(heap[0], candidate)) {
heap[0] = candidate;
heap_sift_down(heap, keep, 0);
}
}

SEXP out = PROTECT(allocVector(VECSXP, 3));
SEXP names = PROTECT(allocVector(STRSXP, 3));
SEXP cutoff = PROTECT(allocVector(REALSXP, 1));
SEXP tail = PROTECT(allocVector(REALSXP, m));
SEXP tail_idx = PROTECT(allocVector(INTSXP, m));

int heap_n = keep;

Node node = heap_pop_min(heap, &heap_n);
REAL(cutoff)[0] = node.value;

for (int i = 0; i < m; ++i) {
node = heap_pop_min(heap, &heap_n);
REAL(tail)[i] = node.value;
INTEGER(tail_idx)[i] = node.index;
}

SET_VECTOR_ELT(out, 0, cutoff);
SET_VECTOR_ELT(out, 1, tail);
SET_VECTOR_ELT(out, 2, tail_idx);

SET_STRING_ELT(names, 0, mkChar("cutoff"));
SET_STRING_ELT(names, 1, mkChar("tail"));
SET_STRING_ELT(names, 2, mkChar("tail_idx"));
setAttrib(out, R_NamesSymbol, names);

UNPROTECT(5);
return out;
}
Loading