diff --git a/DESCRIPTION b/DESCRIPTION
index 07e86da4..2166122a 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,6 +1,6 @@
 Package: epipredict
 Title: Basic epidemiology forecasting methods
-Version: 0.2.5
+Version: 0.2.6
 Authors@R: c(
     person("Daniel J.", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")),
     person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"),
diff --git a/NEWS.md b/NEWS.md
index 38adef1b..cf4c2994 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -2,6 +2,11 @@
 
 Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicate PR's.
 
+# epipredict 0.2.6
+
+- `arx_forecaster()` and `flatline_forecaster()` now error early when `quantile_by_key` contains columns that are not keys of the input `epi_df`, rather than silently dropping the invalid keys (#229).
+- `arx_forecaster()` now warns when `quantile_by_key` is supplied with a quantile-output trainer (`quantile_reg()`, `rand_forest()` with engine `"grf_quantiles"`), where the argument would otherwise be silently ignored (#229).
+
 # epipredict 0.2.5
 
 - Fix `arx_forecaster()` and `arx_fcast_epi_workflow()` so that the error raised when `forecast_date + ahead != target_date` reports the actual validation message rather than a cryptic `cli` template-evaluation error (#473).
diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R
index 4252b2b0..bc479143 100644
--- a/R/arx_forecaster.R
+++ b/R/arx_forecaster.R
@@ -127,6 +127,21 @@ arx_fcast_epi_workflow <- function(
   if (!(is.null(trainer) || is_regression(trainer))) {
     cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.")
   }
+  # Every `quantile_by_key` entry must be a key column of `epi_data`.
+  if (length(args_list$quantile_by_key) > 0L) {
+    valid_keys <- key_colnames(epi_data)
+    missing_keys <- setdiff(args_list$quantile_by_key, valid_keys)
+    if (length(missing_keys) > 0L) {
+      cli_abort(
+        c(
+          "Some {.arg quantile_by_key} columns are not key columns of the input {.cls epi_df}.",
+          "!" = "Missing: {.val {missing_keys}}.",
+          i = "Available keys: {.val {valid_keys}}."
+        ),
+        class = "epipredict__arx_forecaster__quantile_by_key_invalid"
+      )
+    }
+  }
   # forecast_date is above all what they set;
   # if they don't and they're not adjusting latency, it defaults to the max time_value
   # if they're adjusting, it defaults to the as_of
@@ -194,6 +209,21 @@ arx_fcast_epi_workflow <- function(
   f <- frosting() %>% layer_predict() # %>% layer_naomit()
-  is_quantile_reg <- inherits(trainer, "quantile_reg") |
-    (inherits(trainer, "rand_forest") & trainer$engine == "grf_quantiles")
+  # Scalar `||`/`&&` plus `identical()` keep this a length-1 logical even when
+  # `trainer` is `NULL`; `NULL$engine == "..."` would yield `logical(0)`.
+  is_quantile_reg <- inherits(trainer, "quantile_reg") ||
+    (inherits(trainer, "rand_forest") &&
+      identical(trainer$engine, "grf_quantiles"))
+  # Classed warning when `quantile_by_key` is set with a quantile-output trainer.
+  if (is_quantile_reg && length(args_list$quantile_by_key) > 0L) {
+    cli_warn(
+      paste0(
+        "{.arg quantile_by_key} (set to {.val {args_list$quantile_by_key}}) ",
+        "has no effect when the trainer produces quantile distributions ",
+        "directly (e.g., {.fn quantile_reg}, {.fn rand_forest} with engine ",
+        "{.val grf_quantiles}). The argument is being ignored."
+      ),
+      class = "epipredict__arx_forecaster__quantile_by_key_ignored"
+    )
+  }
   if (is_quantile_reg) {
     # add all quantile_level to the forecaster and update postprocessor
     if (inherits(trainer, "quantile_reg")) {
diff --git a/R/flatline_forecaster.R b/R/flatline_forecaster.R
index 617d703e..2f6d4d2e 100644
--- a/R/flatline_forecaster.R
+++ b/R/flatline_forecaster.R
@@ -59,6 +59,21 @@ flatline_forecaster <- function(
   if (!inherits(args_list, c("flat_fcast", "alist"))) {
     cli_abort("`args_list` was not created using `flatline_args_list()`.")
   }
+  # Every `quantile_by_key` entry must be a key column of `epi_data`.
+  if (length(args_list$quantile_by_key) > 0L) {
+    valid_keys <- key_colnames(epi_data)
+    missing_keys <- setdiff(args_list$quantile_by_key, valid_keys)
+    if (length(missing_keys) > 0L) {
+      cli_abort(
+        c(
+          "Some {.arg quantile_by_key} columns are not key columns of the input {.cls epi_df}.",
+          "!" = "Missing: {.val {missing_keys}}.",
+          i = "Available keys: {.val {valid_keys}}."
+        ),
+        class = "epipredict__flatline_forecaster__quantile_by_key_invalid"
+      )
+    }
+  }
   keys <- key_colnames(epi_data)
   ek <- kill_time_value(keys)
   outcome <- rlang::sym(outcome)
diff --git a/tests/testthat/test-arx_forecaster.R b/tests/testthat/test-arx_forecaster.R
index be087297..3a14109f 100644
--- a/tests/testthat/test-arx_forecaster.R
+++ b/tests/testthat/test-arx_forecaster.R
@@ -44,6 +44,27 @@ test_that("warns if there's not enough data to predict", {
   )
 })
 
+test_that("arx_forecaster errors on invalid quantile_by_key columns (issue #229)", {
+  jhu <- epidatasets::covid_case_death_rates
+  expect_error(
+    arx_forecaster(jhu, "death_rate", c("death_rate"),
+      args_list = arx_args_list(quantile_by_key = "nonexistent_column")
+    ),
+    class = "epipredict__arx_forecaster__quantile_by_key_invalid"
+  )
+})
+
+test_that("arx_forecaster warns when quantile_by_key is used with quantile_reg trainer (issue #229)", {
+  jhu <- epidatasets::covid_case_death_rates
+  expect_warning(
+    arx_forecaster(jhu, "death_rate", c("death_rate"),
+      trainer = quantile_reg(),
+      args_list = arx_args_list(quantile_by_key = "geo_value")
+    ),
+    class = "epipredict__arx_forecaster__quantile_by_key_ignored"
+  )
+})
+
 test_that("arx_forecaster errors with documented class when forecast_date + ahead != target_date (issue #473)", {
   df <- tibble(
     geo_value = "ri",
diff --git a/tests/testthat/test-flatline_forecaster.R b/tests/testthat/test-flatline_forecaster.R
index f6d2de24..d6e73147 100644
--- a/tests/testthat/test-flatline_forecaster.R
+++ b/tests/testthat/test-flatline_forecaster.R
@@ -20,3 +20,13 @@ test_that("flatline_forecaster returns one prediction per geo with trailing NAs
   counts <- res$predictions %>% dplyr::count(geo_value, target_date)
   expect_true(all(counts$n == 1L))
 })
+
+test_that("flatline_forecaster errors on invalid quantile_by_key columns (issue #229)", {
+  jhu <- epidatasets::covid_case_death_rates
+  expect_error(
+    flatline_forecaster(jhu, "death_rate",
+      flatline_args_list(quantile_by_key = "nonexistent_column")
+    ),
+    class = "epipredict__flatline_forecaster__quantile_by_key_invalid"
+  )
+})