Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.2.5
Version: 0.2.6
Authors@R: c(
person("Daniel J.", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"),
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicate PR's.

# epipredict 0.2.6

- `arx_forecaster()` and `flatline_forecaster()` now error early when `quantile_by_key` contains columns that are not keys of the input `epi_df`, rather than silently dropping the invalid keys (#229).
- `arx_forecaster()` now warns when `quantile_by_key` is supplied with a quantile-output trainer (`quantile_reg()`, `rand_forest()` with engine `"grf_quantiles"`), where the argument would otherwise be silently ignored (#229).

# epipredict 0.2.5

- Fix `arx_forecaster()` and `arx_fcast_epi_workflow()` so that the error raised when `forecast_date + ahead != target_date` reports the actual validation message rather than a cryptic `cli` template-evaluation error (#473).
Expand Down
25 changes: 25 additions & 0 deletions R/arx_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,20 @@ arx_fcast_epi_workflow <- function(
if (!(is.null(trainer) || is_regression(trainer))) {
cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.")
}
if (length(args_list$quantile_by_key) > 0L) {
valid_keys <- key_colnames(epi_data)
missing_keys <- setdiff(args_list$quantile_by_key, valid_keys)
if (length(missing_keys) > 0L) {
cli_abort(
c(
"Some {.arg quantile_by_key} columns are not key columns of the input {.cls epi_df}.",
"!" = "Missing: {.val {missing_keys}}.",
i = "Available keys: {.val {valid_keys}}."
),
class = "epipredict__arx_forecaster__quantile_by_key_invalid"
)
}
}
# forecast_date is above all what they set;
# if they don't and they're not adjusting latency, it defaults to the max time_value
# if they're adjusting, it defaults to the as_of
Expand Down Expand Up @@ -194,6 +208,17 @@ arx_fcast_epi_workflow <- function(
f <- frosting() %>% layer_predict() # %>% layer_naomit()
is_quantile_reg <- inherits(trainer, "quantile_reg") |
(inherits(trainer, "rand_forest") & trainer$engine == "grf_quantiles")
if (is_quantile_reg && length(args_list$quantile_by_key) > 0L) {
cli_warn(
paste0(
"{.arg quantile_by_key} (set to {.val {args_list$quantile_by_key}}) ",
"has no effect when the trainer produces quantile distributions ",
"directly (e.g., {.fn quantile_reg}, {.fn rand_forest} with engine ",
"{.val grf_quantiles}). The argument is being ignored."
),
class = "epipredict__arx_forecaster__quantile_by_key_ignored"
)
}
if (is_quantile_reg) {
# add all quantile_level to the forecaster and update postprocessor
if (inherits(trainer, "quantile_reg")) {
Expand Down
14 changes: 14 additions & 0 deletions R/flatline_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,20 @@ flatline_forecaster <- function(
if (!inherits(args_list, c("flat_fcast", "alist"))) {
cli_abort("`args_list` was not created using `flatline_args_list()`.")
}
if (length(args_list$quantile_by_key) > 0L) {
valid_keys <- key_colnames(epi_data)
missing_keys <- setdiff(args_list$quantile_by_key, valid_keys)
if (length(missing_keys) > 0L) {
cli_abort(
c(
"Some {.arg quantile_by_key} columns are not key columns of the input {.cls epi_df}.",
"!" = "Missing: {.val {missing_keys}}.",
i = "Available keys: {.val {valid_keys}}."
),
class = "epipredict__flatline_forecaster__quantile_by_key_invalid"
)
}
}
keys <- key_colnames(epi_data)
ek <- kill_time_value(keys)
outcome <- rlang::sym(outcome)
Expand Down
21 changes: 21 additions & 0 deletions tests/testthat/test-arx_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,27 @@ test_that("warns if there's not enough data to predict", {
)
})

test_that("arx_forecaster errors on invalid quantile_by_key columns (issue #229)", {
jhu <- epidatasets::covid_case_death_rates
expect_error(
arx_forecaster(jhu, "death_rate", c("death_rate"),
args_list = arx_args_list(quantile_by_key = "nonexistent_column")
),
class = "epipredict__arx_forecaster__quantile_by_key_invalid"
)
})

test_that("arx_forecaster warns when quantile_by_key is used with quantile_reg trainer (issue #229)", {
jhu <- epidatasets::covid_case_death_rates
expect_warning(
arx_forecaster(jhu, "death_rate", c("death_rate"),
trainer = quantile_reg(),
args_list = arx_args_list(quantile_by_key = "geo_value")
),
class = "epipredict__arx_forecaster__quantile_by_key_ignored"
)
})

test_that("arx_forecaster errors with documented class when forecast_date + ahead != target_date (issue #473)", {
df <- tibble(
geo_value = "ri",
Expand Down
10 changes: 10 additions & 0 deletions tests/testthat/test-flatline_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,13 @@ test_that("flatline_forecaster returns one prediction per geo with trailing NAs
counts <- res$predictions %>% dplyr::count(geo_value, target_date)
expect_true(all(counts$n == 1L))
})

test_that("flatline_forecaster errors on invalid quantile_by_key columns (issue #229)", {
jhu <- epidatasets::covid_case_death_rates
expect_error(
flatline_forecaster(jhu, "death_rate",
flatline_args_list(quantile_by_key = "nonexistent_column")
),
class = "epipredict__flatline_forecaster__quantile_by_key_invalid"
)
})
Loading