diff --git a/R/loo_compare.R b/R/loo_compare.R index fdf0e368..9c90b924 100644 --- a/R/loo_compare.R +++ b/R/loo_compare.R @@ -2,9 +2,6 @@ #' #' @description Compare fitted models based on [ELPD][loo-glossary]. #' -#' By default the print method shows only the most important information. Use -#' `print(..., simplify=FALSE)` to print a more detailed summary. -#' #' @export #' @param x An object of class `"loo"` or a list of such objects. If a list is #' used then the list names will be used as the model names in the output. See @@ -12,8 +9,8 @@ #' @param ... Additional objects of class `"loo"`, if not passed in as a single #' list. #' -#' @return A matrix with class `"compare.loo"` that has its own -#' print method. See the **Details** section. +#' @return A data frame with class `"compare.loo"` that has its own +#' print method. See the **Details** and **Examples** sections. #' #' @details #' When comparing two fitted models, we can estimate the difference in their @@ -21,14 +18,14 @@ #' [`elpd_loo`][loo-glossary] or `elpd_waic` (or multiplied by \eqn{-2}, if #' desired, to be on the deviance scale). #' -#' When using `loo_compare()`, the returned matrix will have one row per model -#' and several columns of estimates. The values in the -#' [`elpd_diff`][loo-glossary] and [`se_diff`][loo-glossary] columns of the -#' returned matrix are computed by making pairwise comparisons between each -#' model and the model with the largest ELPD (the model in the first row). For -#' this reason the `elpd_diff` column will always have the value `0` in the -#' first row (i.e., the difference between the preferred model and itself) and -#' negative values in subsequent rows for the remaining models. +#' ## `elpd_diff` and `se_diff` +#' When using `loo_compare()`, the returned data frame will have one row per +#' model and several columns of estimates. The values of +#' [`elpd_diff`][loo-glossary] and [`se_diff`][loo-glossary] are computed by +#' making pairwise comparisons between each model and the model with the +#' largest ELPD (the model listed first). Therefore, the first `elpd_diff` +#' value will always be `0` (i.e., the difference between the preferred model +#' and itself) and the rest of the values will be negative. #' #' To compute the standard error of the difference in [ELPD][loo-glossary] --- #' which should not be expected to equal the difference of the standard errors @@ -41,9 +38,45 @@ #' standard approach of comparing differences of deviances to a Chi-squared #' distribution, a practice derived for Gaussian linear models or #' asymptotically, and which only applies to nested models in any case. -#' Sivula et al. (2022) discuss the conditions when the normal -#' approximation used for SE and `se_diff` is good. #' +#' ## `p_worse`, `diag_diff`, and `diag_elpd` +#' The values in the `p_worse` column show the probability of each model +#' having worse ELPD than the best model. These probabilities are computed +#' with a normal approximation using the values from `elpd_diff` and +#' `se_diff`. Sivula et al. (2025) present the conditions when the normal +#' approximation used for SE and `se_diff` is good, and the column +#' `diag_diff` contains possible diagnostic messages: +#' +#' * `N < 100` (small data) +#' * `|elpd_diff| < 4` (models make similar predictions) +#' * `k_diff > 0.5` (possible outliers) +#' +#' If any of these diagnostic messages is shown, the error distribution is +#' skewed or thick tailed and the normal approximation based on `elpd_diff` +#' and `se_diff` is not well calibrated. In that case, the probabilities +#' `p_worse` are likely to be too large (small data or similar predictions) or +#' too small (outliers). However, `elpd_diff` and `se_diff` will still be +#' indicative of the differences and uncertainties (for example, if +#' `|elpd_diff|` is many times larger than `se_diff` the difference is quite +#' certain). +#' +#' The `k_diff` value for the `diag_diff` column is computed using the +#' pointwise ELPD differences (and is different from the Pareto k's in +#' PSIS-LOO diagnostic). While `k_diff > 0.5` indicates the *possibility* of +#' outliers, it is also possible that both models compared seem to be well +#' specified based on model checking, but the pointwise ELPD differences have +#' such thick tails that the normal approximation for the sum is not good +#' (Vehtari et al., 2024). A threshold of 0.5 is used for `k_diff` as we do +#' not do automatic Pareto smoothing for the pointwise differences (Vehtari et +#' al., 2024). +#' +#' The column `diag_elpd` shows the PSIS-LOO Pareto k diagnostic for the +#' pointwise ELPD computations for each model. If `K k_psis > 0.7` is shown, +#' where `K` is the number of high high Pareto k values in the PSIS +#' computation, then there may be significant bias in `elpd_diff` favoring +#' models with a large number of high Pareto k values. +#' +#' ## Warnings for many model comparisons #' If more than \eqn{11} models are compared, we internally recompute the model #' differences using the median model by ELPD as the baseline model. We then #' estimate whether the differences in predictive performance are potentially @@ -52,7 +85,7 @@ #' selection process. In that case users are recommended to avoid model #' selection based on LOO-CV, and instead to favor model averaging/stacking or #' projection predictive inference. -#' +#' #' @seealso #' * The [FAQ page](https://mc-stan.org/loo/articles/online-only/faq.html) on #' the __loo__ website for answers to frequently asked questions. @@ -68,12 +101,8 @@ #' comp <- loo_compare(loo1, loo2, loo3) #' print(comp, digits = 2) #' -#' # show more details with simplify=FALSE -#' # (will be the same for all models in this artificial example) -#' print(comp, simplify = FALSE, digits = 3) -#' #' # can use a list of objects with custom names -#' # will use apple, banana, and cherry, as the names in the output +#' # the names will be used in the output #' loo_compare(list("apple" = loo1, "banana" = loo2, "cherry" = loo3)) #' #' \dontrun{ @@ -101,54 +130,83 @@ loo_compare.default <- function(x, ...) { loos <- x } - # If subsampling is used + # if subsampling is used if (any(sapply(loos, inherits, "psis_loo_ss"))) { return(loo_compare.psis_loo_ss_list(loos)) } + # run pre-comparison checks loo_compare_checks(loos) + # compute elpd_diff and se_elpd_diff relative to best model comp <- loo_compare_matrix(loos) ord <- loo_compare_order(loos) - - # compute elpd_diff and se_elpd_diff relative to best model rnms <- rownames(comp) diffs <- mapply(FUN = elpd_diffs, loos[ord[1]], loos[ord]) + colnames(diffs) <- rnms elpd_diff <- apply(diffs, 2, sum) se_diff <- apply(diffs, 2, se_elpd_diff) - comp <- cbind(elpd_diff = elpd_diff, se_diff = se_diff, comp) - rownames(comp) <- rnms - # run order statistics-based checks on models + # compute probabilities that a model has worse elpd than the best model + # using a normal approximation (Sivula et al., 2025) + p_worse <- stats::pnorm(0, elpd_diff, se_diff) + p_worse[elpd_diff == 0] <- NA + + comp <- cbind( + data.frame( + model = rnms, + elpd_diff = elpd_diff, + se_diff = se_diff, + p_worse = p_worse, + diag_diff = diag_diff(nrow(diffs), elpd_diff), + diag_elpd = diag_elpd(loos[ord]) + ), + as.data.frame(comp) + ) + rownames(comp) <- NULL + + # run order statistics-based checks for many model comparisons loo_order_stat_check(loos, ord) class(comp) <- c("compare.loo", class(comp)) - return(comp) + comp } #' @rdname loo_compare #' @export #' @param digits For the print method only, the number of digits to use when #' printing. -#' @param simplify For the print method only, should only the essential columns -#' of the summary matrix be printed? The entire matrix is always returned, but -#' by default only the most important columns are printed. -print.compare.loo <- function(x, ..., digits = 1, simplify = TRUE) { - xcopy <- x - if (inherits(xcopy, "old_compare.loo")) { - if (NCOL(xcopy) >= 2 && simplify) { - patts <- "^elpd_|^se_diff|^p_|^waic$|^looic$" - xcopy <- xcopy[, grepl(patts, colnames(xcopy))] - } - } else if (NCOL(xcopy) >= 2 && simplify) { - xcopy <- xcopy[, c("elpd_diff", "se_diff")] +#' @param p_worse For the print method only, should we include the normal +#' approximation based probability of each model having worse performance than +#' the best model? The default is `TRUE`. +print.compare.loo <- function(x, ..., digits = 1, p_worse = TRUE) { + if (inherits(x, "old_compare.loo")) { + return(unclass(x)) + } + if (!inherits(x, "data.frame")) { + class(x) <- c(class(x), "data.frame") + } + if (!all(c("model", "elpd_diff", "se_diff") %in% colnames(x))) { + print(as.data.frame(x)) + return(x) } - print(.fr(xcopy, digits), quote = FALSE) + x2 <- cbind( + model = x$model, + .fr(x[, c("elpd_diff", "se_diff")], digits) + ) + if (p_worse && "p_worse" %in% colnames(x)) { + x2 <- cbind( + x2, + p_worse = .fr(x[, "p_worse"], digits = 2), + diag_diff = x[, "diag_diff"], + diag_elpd = x[, "diag_elpd"] + ) + } + print(x2, quote = FALSE, row.names = FALSE) invisible(x) } - # internal ---------------------------------------------------------------- #' Compute pointwise elpd differences @@ -172,7 +230,6 @@ se_elpd_diff <- function(diffs) { sqrt(N) * sd(diffs) } - #' Perform checks on `"loo"` objects before comparison #' @noRd #' @param loos List of `"loo"` objects. @@ -227,7 +284,6 @@ loo_compare_checks <- function(loos) { #' Find the model names associated with `"loo"` objects #' #' @export -#' @keywords internal #' @param x List of `"loo"` objects. #' @return Character vector of model names the same length as `x.` #' @@ -256,7 +312,6 @@ find_model_names <- function(x) { #' Compute the loo_compare matrix -#' @keywords internal #' @noRd #' @param loos List of `"loo"` objects. loo_compare_matrix <- function(loos){ @@ -278,7 +333,6 @@ loo_compare_matrix <- function(loos){ #' Computes the order of loos for comparison #' @noRd -#' @keywords internal #' @param loos List of `"loo"` objects. loo_compare_order <- function(loos){ tmp <- sapply(loos, function(x) { @@ -293,7 +347,6 @@ loo_compare_order <- function(loos){ #' Perform checks on `"loo"` objects __after__ comparison #' @noRd -#' @keywords internal #' @param loos List of `"loo"` objects. #' @param ord List of `"loo"` object orderings. #' @return Nothing, just possibly throws errors/warnings. @@ -335,14 +388,12 @@ loo_order_stat_check <- function(loos, ord) { #' Returns the middle index of a vector #' @noRd -#' @keywords internal #' @param vec A vector. #' @return Integer index value. middle_idx <- function(vec) floor(length(vec) / 2) #' Computes maximum order statistic from K Gaussians #' @noRd -#' @keywords internal #' @param K Number of Gaussians. #' @param c Scaling of the order statistic. #' @return Numeric expected maximum from K samples from a Gaussian with mean @@ -350,3 +401,44 @@ middle_idx <- function(vec) floor(length(vec) / 2) order_stat_heuristic <- function(K, c) { qnorm(p = 1 - 1 / (K * 2), mean = 0, sd = c) } + +#' Count number of high Pareto k values in PSIS-LOO and create diagnostic message +#' @noRd +#' @param loos Ordered list of loo objects. +#' @return Character vector of diagnostic messages. +diag_elpd <- function(loos) { + sapply(loos, function(loo) { + k <- loo$diagnostics[["pareto_k"]] + if (is.null(k)) { + out <- "" + } else { + S <- dim(loo)[1] + khat_threshold <- ps_khat_threshold(S) + K <- sum(k > khat_threshold) + out <- ifelse(K == 0, "", paste0(K, " k_psis > ", round(khat_threshold, 2))) + } + out + }) +} + +#' Create diagnostic for elpd differences +#' @noRd +#' @param N Number of data points. +#' @param elpd_diff Vector of elpd differences. +#' @return Character vector of diagnostic messages. +diag_diff <- function(N, elpd_diff) { + if (N < 100) { + diag_diff <- rep("N < 100", length(elpd_diff)) + diag_diff[elpd_diff == 0] <- "" + } else { + diag_diff <- rep("", length(elpd_diff)) + diag_diff[elpd_diff > -4 & elpd_diff != 0] <- "|elpd_diff| < 4" + k_diff <- rep(NA, length(elpd_diff)) + k_diff[elpd_diff != 0] <- apply( + diffs[, elpd_diff != 0, drop = FALSE], 2, + function(x) ifelse(length(unique(x)) <= 20, NA, posterior::pareto_khat(x, tail = "both") + )) + diag_diff[k_diff > 0.5] <- "k_diff > 0.5" + } + diag_diff +} diff --git a/R/loo_compare.psis_loo_ss_list.R b/R/loo_compare.psis_loo_ss_list.R index acd0690b..9c0db564 100644 --- a/R/loo_compare.psis_loo_ss_list.R +++ b/R/loo_compare.psis_loo_ss_list.R @@ -173,14 +173,9 @@ loo_compare_checks.psis_loo_ss_list <- function(loos) { #' @rdname loo_compare #' @export -print.compare.loo_ss <- function(x, ..., digits = 1, simplify = TRUE) { +print.compare.loo_ss <- function(x, ..., digits = 1) { xcopy <- x - if (inherits(xcopy, "old_compare.loo")) { - if (NCOL(xcopy) >= 2 && simplify) { - patts <- "^elpd_|^se_diff|^p_|^waic$|^looic$" - xcopy <- xcopy[, grepl(patts, colnames(xcopy))] - } - } else if (NCOL(xcopy) >= 2 && simplify) { + if (NCOL(xcopy) >= 2) { xcopy <- xcopy[, c("elpd_diff", "se_diff", "subsampling_se_diff")] } print(.fr(xcopy, digits), quote = FALSE) diff --git a/man/find_model_names.Rd b/man/find_model_names.Rd index cdb76b80..70a79d58 100644 --- a/man/find_model_names.Rd +++ b/man/find_model_names.Rd @@ -15,4 +15,3 @@ Character vector of model names the same length as \code{x.} \description{ Find the model names associated with \code{"loo"} objects } -\keyword{internal} diff --git a/man/loo_compare.Rd b/man/loo_compare.Rd index bd780aa5..da0bf9cc 100644 --- a/man/loo_compare.Rd +++ b/man/loo_compare.Rd @@ -11,9 +11,9 @@ loo_compare(x, ...) \method{loo_compare}{default}(x, ...) -\method{print}{compare.loo}(x, ..., digits = 1, simplify = TRUE) +\method{print}{compare.loo}(x, ..., digits = 1, p_worse = TRUE) -\method{print}{compare.loo_ss}(x, ..., digits = 1, simplify = TRUE) +\method{print}{compare.loo_ss}(x, ..., digits = 1) } \arguments{ \item{x}{An object of class \code{"loo"} or a list of such objects. If a list is @@ -26,34 +26,31 @@ list.} \item{digits}{For the print method only, the number of digits to use when printing.} -\item{simplify}{For the print method only, should only the essential columns -of the summary matrix be printed? The entire matrix is always returned, but -by default only the most important columns are printed.} +\item{p_worse}{For the print method only, should we include the normal +approximation based probability of each model having worse performance than +the best model? The default is \code{TRUE}.} } \value{ -A matrix with class \code{"compare.loo"} that has its own -print method. See the \strong{Details} section. +A data frame with class \code{"compare.loo"} that has its own +print method. See the \strong{Details} and \strong{Examples} sections. } \description{ Compare fitted models based on \link[=loo-glossary]{ELPD}. - -By default the print method shows only the most important information. Use -\code{print(..., simplify=FALSE)} to print a more detailed summary. } \details{ When comparing two fitted models, we can estimate the difference in their expected predictive accuracy by the difference in \code{\link[=loo-glossary]{elpd_loo}} or \code{elpd_waic} (or multiplied by \eqn{-2}, if desired, to be on the deviance scale). +\subsection{\code{elpd_diff} and \code{se_diff}}{ -When using \code{loo_compare()}, the returned matrix will have one row per model -and several columns of estimates. The values in the -\code{\link[=loo-glossary]{elpd_diff}} and \code{\link[=loo-glossary]{se_diff}} columns of the -returned matrix are computed by making pairwise comparisons between each -model and the model with the largest ELPD (the model in the first row). For -this reason the \code{elpd_diff} column will always have the value \code{0} in the -first row (i.e., the difference between the preferred model and itself) and -negative values in subsequent rows for the remaining models. +When using \code{loo_compare()}, the returned data frame will have one row per +model and several columns of estimates. The values of +\code{\link[=loo-glossary]{elpd_diff}} and \code{\link[=loo-glossary]{se_diff}} are computed by +making pairwise comparisons between each model and the model with the +largest ELPD (the model listed first). Therefore, the first \code{elpd_diff} +value will always be \code{0} (i.e., the difference between the preferred model +and itself) and the rest of the values will be negative. To compute the standard error of the difference in \link[=loo-glossary]{ELPD} --- which should not be expected to equal the difference of the standard errors @@ -66,8 +63,49 @@ better sense of uncertainty than what is obtained using the current standard approach of comparing differences of deviances to a Chi-squared distribution, a practice derived for Gaussian linear models or asymptotically, and which only applies to nested models in any case. -Sivula et al. (2022) discuss the conditions when the normal -approximation used for SE and \code{se_diff} is good. +} + +\subsection{\code{p_worse}, \code{diag_diff}, and \code{diag_elpd}}{ + +The values in the \code{p_worse} column show the probability of each model +having worse ELPD than the best model. These probabilities are computed +with a normal approximation using the values from \code{elpd_diff} and +\code{se_diff}. Sivula et al. (2025) present the conditions when the normal +approximation used for SE and \code{se_diff} is good, and the column +\code{diag_diff} contains possible diagnostic messages: +\itemize{ +\item \code{N < 100} (small data) +\item \verb{|elpd_diff| < 4} (models make similar predictions) +\item \code{k_diff > 0.5} (possible outliers) +} + +If any of these diagnostic messages is shown, the error distribution is +skewed or thick tailed and the normal approximation based on \code{elpd_diff} +and \code{se_diff} is not well calibrated. In that case, the probabilities +\code{p_worse} are likely to be too large (small data or similar predictions) or +too small (outliers). However, \code{elpd_diff} and \code{se_diff} will still be +indicative of the differences and uncertainties (for example, if +\verb{|elpd_diff|} is many times larger than \code{se_diff} the difference is quite +certain). + +The \code{k_diff} value for the \code{diag_diff} column is computed using the +pointwise ELPD differences (and is different from the Pareto k's in +PSIS-LOO diagnostic). While \code{k_diff > 0.5} indicates the \emph{possibility} of +outliers, it is also possible that both models compared seem to be well +specified based on model checking, but the pointwise ELPD differences have +such thick tails that the normal approximation for the sum is not good +(Vehtari et al., 2024). A threshold of 0.5 is used for \code{k_diff} as we do +not do automatic Pareto smoothing for the pointwise differences (Vehtari et +al., 2024). + +The column \code{diag_elpd} shows the PSIS-LOO Pareto k diagnostic for the +pointwise ELPD computations for each model. If \verb{K k_psis > 0.7} is shown, +where \code{K} is the number of high high Pareto k values in the PSIS +computation, then there may be significant bias in \code{elpd_diff} favoring +models with a large number of high Pareto k values. +} + +\subsection{Warnings for many model comparisons}{ If more than \eqn{11} models are compared, we internally recompute the model differences using the median model by ELPD as the baseline model. We then @@ -78,6 +116,7 @@ selection process. In that case users are recommended to avoid model selection based on LOO-CV, and instead to favor model averaging/stacking or projection predictive inference. } +} \examples{ # very artificial example, just for demonstration! LL <- example_loglik_array() @@ -88,12 +127,8 @@ loo3 <- loo(LL + 2) # should be best model when compared comp <- loo_compare(loo1, loo2, loo3) print(comp, digits = 2) -# show more details with simplify=FALSE -# (will be the same for all models in this artificial example) -print(comp, simplify = FALSE, digits = 3) - # can use a list of objects with custom names -# will use apple, banana, and cherry, as the names in the output +# the names will be used in the output loo_compare(list("apple" = loo1, "banana" = loo2, "cherry" = loo3)) \dontrun{ diff --git a/tests/testthat/_snaps/compare.md b/tests/testthat/_snaps/compare.md index cf27900e..1c23a744 100644 --- a/tests/testthat/_snaps/compare.md +++ b/tests/testthat/_snaps/compare.md @@ -1,44 +1,89 @@ # loo_compare returns expected results (2 models) - WAoAAAACAAQFAAACAwAAAAMOAAAAEAAAAAAAAAAAwBA6U1+cRe4AAAAAAAAAAD+2ake0LxMB - wFTh8N3JQljAVeWWE8MGuUARCD2zEXBfQBEalRIN2T9ACijAYdW5U0AmZ5XrANCKP/H9Zexy - 814/8ZtgnG1nx0Bk4fDdyUJYQGXllhPDBrlAIQg9sxFwX0AhGpUSDdk/AAAEAgAAAAEABAAJ - AAAAA2RpbQAAAA0AAAACAAAAAgAAAAgAAAQCAAAAAQAEAAkAAAAIZGltbmFtZXMAAAATAAAA - AgAAABAAAAACAAQACQAAAAZtb2RlbDEABAAJAAAABm1vZGVsMgAAABAAAAAIAAQACQAAAAll - bHBkX2RpZmYABAAJAAAAB3NlX2RpZmYABAAJAAAACWVscGRfd2FpYwAEAAkAAAAMc2VfZWxw - ZF93YWljAAQACQAAAAZwX3dhaWMABAAJAAAACXNlX3Bfd2FpYwAEAAkAAAAEd2FpYwAEAAkA - AAAHc2Vfd2FpYwAABAIAAAABAAQACQAAAAVjbGFzcwAAABAAAAADAAQACQAAAAtjb21wYXJl - LmxvbwAEAAkAAAAGbWF0cml4AAQACQAAAAVhcnJheQAAAP4= + WAoAAAACAAQEAgACAwAAAAMTAAAADAAAABAAAAACAAQACQAAAAZtb2RlbDEABAAJAAAABm1v + ZGVsMgAAAA4AAAACAAAAAAAAAAAAAAAAAAAAAAAAAA4AAAACAAAAAAAAAAAAAAAAAAAAAAAA + AA4AAAACf/AAAAAAB6J/8AAAAAAHogAAABAAAAACAAQACQAAAAAABAAJAAAAAAAAABAAAAAC + AAQACQAAAAAABAAJAAAAAAAAAA4AAAACwFTh8N3JQljAVOHw3clCWAAAAA4AAAACQBEIPbMR + cF9AEQg9sxFwXwAAAA4AAAACQAoowGHVuVNACijAYdW5UwAAAA4AAAACP/H9Zexy814/8f1l + 7HLzXgAAAA4AAAACQGTh8N3JQlhAZOHw3clCWAAAAA4AAAACQCEIPbMRcF9AIQg9sxFwXwAA + BAIAAAABAAQACQAAAAVuYW1lcwAAABAAAAAMAAQACQAAAAVtb2RlbAAEAAkAAAAJZWxwZF9k + aWZmAAQACQAAAAdzZV9kaWZmAAQACQAAAAdwX3dvcnNlAAQACQAAAAlkaWFnX2RpZmYABAAJ + AAAACWRpYWdfZWxwZAAEAAkAAAAJZWxwZF93YWljAAQACQAAAAxzZV9lbHBkX3dhaWMABAAJ + AAAABnBfd2FpYwAEAAkAAAAJc2VfcF93YWljAAQACQAAAAR3YWljAAQACQAAAAdzZV93YWlj + AAAEAgAAAAEABAAJAAAABWNsYXNzAAAAEAAAAAIABAAJAAAAC2NvbXBhcmUubG9vAAQACQAA + AApkYXRhLmZyYW1lAAAEAgAAAAEABAAJAAAACXJvdy5uYW1lcwAAAA0AAAACgAAAAP////4A + AAD+ -# loo_compare returns expected result (3 models) +--- - WAoAAAACAAQFAAACAwAAAAMOAAAAGAAAAAAAAAAAwBA6U1+cRe7AMA3KkbYEGAAAAAAAAAAA - P7ZqR7QvEwE/y6/t4TTtXsBU4fDdyUJYwFXllhPDBrnAWOVjgjbDYkARCD2zEXBfQBEalRIN - 2T9AEPIF3GigE0AKKMBh1blTQCZnlesA0IpAQcjYUhrdCj/x/WXscvNeP/GbYJxtZ8c/8YDQ - kmfJX0Bk4fDdyUJYQGXllhPDBrlAaOVjgjbDYkAhCD2zEXBfQCEalRIN2T9AIPIF3GigEwAA - BAIAAAABAAQACQAAAANkaW0AAAANAAAAAgAAAAMAAAAIAAAEAgAAAAEABAAJAAAACGRpbW5h - bWVzAAAAEwAAAAIAAAAQAAAAAwAEAAkAAAAGbW9kZWwxAAQACQAAAAZtb2RlbDIABAAJAAAA - Bm1vZGVsMwAAABAAAAAIAAQACQAAAAllbHBkX2RpZmYABAAJAAAAB3NlX2RpZmYABAAJAAAA - CWVscGRfd2FpYwAEAAkAAAAMc2VfZWxwZF93YWljAAQACQAAAAZwX3dhaWMABAAJAAAACXNl - X3Bfd2FpYwAEAAkAAAAEd2FpYwAEAAkAAAAHc2Vfd2FpYwAABAIAAAABAAQACQAAAAVjbGFz - cwAAABAAAAADAAQACQAAAAtjb21wYXJlLmxvbwAEAAkAAAAGbWF0cml4AAQACQAAAAVhcnJh - eQAAAP4= + Code + print(comp1) + Output + model elpd_diff se_diff p_worse diag_diff diag_elpd + model1 0.0 0.0 NA + model2 0.0 0.0 NA + +--- -# compare returns expected result (2 models) + WAoAAAACAAQEAgACAwAAAAMTAAAADAAAABAAAAACAAQACQAAAAZtb2RlbDEABAAJAAAABm1v + ZGVsMgAAAA4AAAACAAAAAAAAAADAEDpTX5xF7gAAAA4AAAACAAAAAAAAAAA/tmpHtC8TAQAA + AA4AAAACf/AAAAAAB6I/8AAAAAAAAAAAABAAAAACAAQACQAAAAAABAAJAAAAB04gPCAxMDAA + AAAQAAAAAgAEAAkAAAAAAAQACQAAAAAAAAAOAAAAAsBU4fDdyUJYwFXllhPDBrkAAAAOAAAA + AkARCD2zEXBfQBEalRIN2T8AAAAOAAAAAkAKKMBh1blTQCZnlesA0IoAAAAOAAAAAj/x/WXs + cvNeP/GbYJxtZ8cAAAAOAAAAAkBk4fDdyUJYQGXllhPDBrkAAAAOAAAAAkAhCD2zEXBfQCEa + lRIN2T8AAAQCAAAAAQAEAAkAAAAFbmFtZXMAAAAQAAAADAAEAAkAAAAFbW9kZWwABAAJAAAA + CWVscGRfZGlmZgAEAAkAAAAHc2VfZGlmZgAEAAkAAAAHcF93b3JzZQAEAAkAAAAJZGlhZ19k + aWZmAAQACQAAAAlkaWFnX2VscGQABAAJAAAACWVscGRfd2FpYwAEAAkAAAAMc2VfZWxwZF93 + YWljAAQACQAAAAZwX3dhaWMABAAJAAAACXNlX3Bfd2FpYwAEAAkAAAAEd2FpYwAEAAkAAAAH + c2Vfd2FpYwAABAIAAAABAAQACQAAAAVjbGFzcwAAABAAAAACAAQACQAAAAtjb21wYXJlLmxv + bwAEAAkAAAAKZGF0YS5mcmFtZQAABAIAAAABAAQACQAAAAlyb3cubmFtZXMAAAANAAAAAoAA + AAD////+AAAA/g== + +--- + + Code + print(comp2) + Output + model elpd_diff se_diff p_worse diag_diff diag_elpd + model1 0.0 0.0 NA + model2 -4.1 0.1 1.00 N < 100 + +--- Code - comp1 + print(comp2, p_worse = FALSE) Output - elpd_diff se - 0.0 0.0 + model elpd_diff se_diff + model1 0.0 0.0 + model2 -4.1 0.1 + +# loo_compare returns expected result (3 models) + + WAoAAAACAAQEAgACAwAAAAMTAAAADAAAABAAAAADAAQACQAAAAZtb2RlbDEABAAJAAAABm1v + ZGVsMgAEAAkAAAAGbW9kZWwzAAAADgAAAAMAAAAAAAAAAMAQOlNfnEXuwDANypG2BBgAAAAO + AAAAAwAAAAAAAAAAP7ZqR7QvEwE/y6/t4TTtXgAAAA4AAAADf/AAAAAAB6I/8AAAAAAAAD/w + AAAAAAAAAAAAEAAAAAMABAAJAAAAAAAEAAkAAAAHTiA8IDEwMAAEAAkAAAAHTiA8IDEwMAAA + ABAAAAADAAQACQAAAAAABAAJAAAAAAAEAAkAAAAAAAAADgAAAAPAVOHw3clCWMBV5ZYTwwa5 + wFjlY4I2w2IAAAAOAAAAA0ARCD2zEXBfQBEalRIN2T9AEPIF3GigEwAAAA4AAAADQAoowGHV + uVNAJmeV6wDQikBByNhSGt0KAAAADgAAAAM/8f1l7HLzXj/xm2CcbWfHP/GA0JJnyV8AAAAO + AAAAA0Bk4fDdyUJYQGXllhPDBrlAaOVjgjbDYgAAAA4AAAADQCEIPbMRcF9AIRqVEg3ZP0Ag + 8gXcaKATAAAEAgAAAAEABAAJAAAABW5hbWVzAAAAEAAAAAwABAAJAAAABW1vZGVsAAQACQAA + AAllbHBkX2RpZmYABAAJAAAAB3NlX2RpZmYABAAJAAAAB3Bfd29yc2UABAAJAAAACWRpYWdf + ZGlmZgAEAAkAAAAJZGlhZ19lbHBkAAQACQAAAAllbHBkX3dhaWMABAAJAAAADHNlX2VscGRf + d2FpYwAEAAkAAAAGcF93YWljAAQACQAAAAlzZV9wX3dhaWMABAAJAAAABHdhaWMABAAJAAAA + B3NlX3dhaWMAAAQCAAAAAQAEAAkAAAAFY2xhc3MAAAAQAAAAAgAEAAkAAAALY29tcGFyZS5s + b28ABAAJAAAACmRhdGEuZnJhbWUAAAQCAAAAAQAEAAkAAAAJcm93Lm5hbWVzAAAADQAAAAKA + AAAA/////QAAAP4= --- Code - comp2 + print(comp1) Output - elpd_diff se - -4.1 0.1 + model elpd_diff se_diff p_worse diag_diff diag_elpd + model1 0.0 0.0 NA + model2 -4.1 0.1 1.00 N < 100 + model3 -16.1 0.2 1.00 N < 100 # compare returns expected result (3 models) diff --git a/tests/testthat/test_compare.R b/tests/testthat/test_compare.R index 7f720b22..aba34209 100644 --- a/tests/testthat/test_compare.R +++ b/tests/testthat/test_compare.R @@ -73,8 +73,12 @@ test_that("loo_compare throws appropriate warnings", { comp_colnames <- c( + "model", "elpd_diff", "se_diff", + "p_worse", + "diag_diff", + "diag_elpd", "elpd_waic", "se_elpd_waic", "p_waic", @@ -86,20 +90,31 @@ comp_colnames <- c( test_that("loo_compare returns expected results (2 models)", { comp1 <- loo_compare(w1, w1) expect_s3_class(comp1, "compare.loo") + expect_s3_class(comp1, "data.frame") expect_equal(colnames(comp1), comp_colnames) - expect_equal(rownames(comp1), c("model1", "model2")) - expect_output(print(comp1), "elpd_diff") - expect_equal(comp1[1:2, 1], c(0, 0), ignore_attr = TRUE) - expect_equal(comp1[1:2, 2], c(0, 0), ignore_attr = TRUE) + expect_equal(comp1$model, c("model1", "model2")) + expect_equal(comp1$elpd_diff, c(0, 0), ignore_attr = TRUE) + expect_equal(comp1$se_diff, c(0, 0), ignore_attr = TRUE) + expect_equal(comp1$p_worse, c(NA_real_, NA_real_), ignore_attr = TRUE) + expect_snapshot_value(comp1, style = "serialize") + expect_snapshot(print(comp1)) comp2 <- loo_compare(w1, w2) expect_s3_class(comp2, "compare.loo") expect_equal(colnames(comp2), comp_colnames) - + expect_equal(comp2$p_worse, c(NA, 1)) + expect_equal(comp2$diag_diff, c("", "N < 100")) + expect_equal(comp2$diag_elpd, c("", "")) expect_snapshot_value(comp2, style = "serialize") + expect_snapshot(print(comp2)) + expect_snapshot(print(comp2, p_worse = FALSE)) # specifying objects via ... and via arg x gives equal results expect_equal(comp2, loo_compare(x = list(w1, w2))) + + # custom naming works + comp3 <- loo_compare(x = list("A" = w2, "B" = w1)) + expect_equal(comp3$model, c("B", "A")) }) @@ -108,12 +123,13 @@ test_that("loo_compare returns expected result (3 models)", { comp1 <- loo_compare(w1, w2, w3) expect_equal(colnames(comp1), comp_colnames) - expect_equal(rownames(comp1), c("model1", "model2", "model3")) - expect_equal(comp1[1, 1], 0) + expect_equal(comp1$model, c("model1", "model2", "model3")) + expect_equal(comp1$p_worse, c(NA, 1, 1)) + expect_equal(comp1$diag_diff, c("", "N < 100", "N < 100")) expect_s3_class(comp1, "compare.loo") - expect_s3_class(comp1, "matrix") - + expect_s3_class(comp1, "data.frame") expect_snapshot_value(comp1, style = "serialize") + expect_snapshot(print(comp1)) # specifying objects via '...' gives equivalent results (equal # except rownames) to using 'x' argument @@ -129,13 +145,11 @@ test_that("compare throws deprecation warnings", { test_that("compare returns expected result (2 models)", { expect_warning(comp1 <- loo::compare(w1, w1), "Deprecated") - expect_snapshot(comp1) expect_equal(comp1[1:2], c(elpd_diff = 0, se = 0)) expect_warning(comp2 <- loo::compare(w1, w2), "Deprecated") - expect_snapshot(comp2) - expect_named(comp2, c("elpd_diff", "se")) - expect_s3_class(comp2, "compare.loo") + expect_equal(round(comp2[1:2], 3), c(elpd_diff = -4.057, se = 0.088)) + expect_s3_class(comp2, "old_compare.loo") # specifying objects via ... and via arg x gives equal results expect_warning(comp_via_list <- loo::compare(x = list(w1, w2)), "Deprecated")