diff --git a/DESCRIPTION b/DESCRIPTION index 456cdb54..0f27f9cc 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -72,6 +72,7 @@ Remotes: cmu-delphi/epidatasets, cmu-delphi/epidatr, cmu-delphi/epiprocess, + cmu-delphi/epidatasets, dajmcdon/smoothqr Config/testthat/edition: 3 Encoding: UTF-8 diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R index c7aebef4..b386fe45 100644 --- a/R/arx_forecaster.R +++ b/R/arx_forecaster.R @@ -186,14 +186,26 @@ arx_fcast_epi_workflow <- function( # --- postprocessor f <- frosting() %>% layer_predict() # %>% layer_naomit() - if (inherits(trainer, "quantile_reg")) { + is_quantile_reg <- inherits(trainer, "quantile_reg") | + (inherits(trainer, "rand_forest") & trainer$engine == "grf_quantiles") + if (is_quantile_reg) { # add all quantile_level to the forecaster and update postprocessor - quantile_levels <- sort(compare_quantile_args( - args_list$quantile_levels, - rlang::eval_tidy(trainer$args$quantile_levels) - )) + if (inherits(trainer, "quantile_reg")) { + quantile_levels <- sort(compare_quantile_args( + args_list$quantile_levels, + rlang::eval_tidy(trainer$args$quantile_levels), + "qr" + )) + trainer$args$quantile_levels <- rlang::enquo(quantile_levels) + } else { + quantile_levels <- sort(compare_quantile_args( + args_list$quantile_levels, + rlang::eval_tidy(trainer$eng_args$quantiles) %||% c(.1, .5, .9), + "grf" + )) + trainer$eng_args$quantiles <- rlang::enquo(quantile_levels) + } args_list$quantile_levels <- quantile_levels - trainer$args$quantile_levels <- rlang::enquo(quantile_levels) f <- f %>% layer_quantile_distn(quantile_levels = quantile_levels) %>% layer_point_from_distn() @@ -345,9 +357,13 @@ print.arx_fcast <- function(x, ...) { NextMethod(name = name, ...) } -compare_quantile_args <- function(alist, tlist) { +compare_quantile_args <- function(alist, tlist, train_method = c("qr", "grf")) { + train_method <- rlang::arg_match(train_method) default_alist <- eval(formals(arx_args_list)$quantile_levels) - default_tlist <- eval(formals(quantile_reg)$quantile_levels) + default_tlist <- switch(train_method, + "qr" = eval(formals(quantile_reg)$quantile_levels), + "grf" = c(.1, .5, .9) + ) if (setequal(alist, default_alist)) { if (setequal(tlist, default_tlist)) { return(sort(unique(union(alist, tlist)))) diff --git a/R/canned-epipred.R b/R/canned-epipred.R index 1e088426..8a5baa58 100644 --- a/R/canned-epipred.R +++ b/R/canned-epipred.R @@ -77,8 +77,9 @@ print.canned_epipred <- function(x, name, ...) { fn_meta <- function() { cli::cli_ul() cli::cli_li("Geography: {.field {x$metadata$training$geo_type}},") - if (!is.null(x$metadata$training$other_keys)) { - cli::cli_li("Other keys: {.field {x$metadata$training$other_keys}},") + other_keys <- x$metadata$training$other_keys + if (!is.null(other_keys) && length(other_keys) > 0L) { + cli::cli_li("Other keys: {.field {other_keys}},") } cli::cli_li("Time type: {.field {x$metadata$training$time_type}},") cli::cli_li("Using data up-to-date as of: {.field {format(x$metadata$training$as_of)}}.") diff --git a/man/step_adjust_latency.Rd b/man/step_adjust_latency.Rd index 1a677042..0078de10 100644 --- a/man/step_adjust_latency.Rd +++ b/man/step_adjust_latency.Rd @@ -267,8 +267,8 @@ while this will not: \if{html}{\out{
}}\preformatted{toy_recipe <- epi_recipe(toy_df) \%>\% step_epi_lag(a, lag=0) \%>\% step_adjust_latency(a, method = "extend_lags") -#> Warning: If `method` is "extend_lags" or "locf", then the previous -#> `step_epi_lag`s won't work with modified data. +#> Warning: If `method` is "extend_lags" or "locf", then the previous `step_epi_lag`s won't +#> work with modified data. }\if{html}{\out{
}} If you create columns that you then apply lags to (such as diff --git a/tests/testthat/_snaps/arx_args_list.md b/tests/testthat/_snaps/arx_args_list.md index 959a5e25..2579c5f0 100644 --- a/tests/testthat/_snaps/arx_args_list.md +++ b/tests/testthat/_snaps/arx_args_list.md @@ -124,6 +124,15 @@ # arx forecaster disambiguates quantiles + Code + compare_quantile_args(alist / 10, 1:9 / 10, "grf") + Condition + Error in `compare_quantile_args()`: + ! You have specified different, non-default, quantiles in the trainier and `arx_args` options. + i Please only specify quantiles in one location. + +--- + Code compare_quantile_args(alist, tlist) Condition diff --git a/tests/testthat/_snaps/snapshots.md b/tests/testthat/_snaps/snapshots.md index a03a8dd4..f3e7e573 100644 --- a/tests/testthat/_snaps/snapshots.md +++ b/tests/testthat/_snaps/snapshots.md @@ -1093,7 +1093,6 @@ Training data was an with: * Geography: state, - * Other keys: , * Time type: day, * Using data up-to-date as of: 2022-05-31. * With the last data available on 2021-12-31 @@ -1117,7 +1116,6 @@ Training data was an with: * Geography: state, - * Other keys: , * Time type: day, * Using data up-to-date as of: 2022-05-31. * With the last data available on 2021-12-31 @@ -1142,7 +1140,6 @@ Training data was an with: * Geography: state, - * Other keys: , * Time type: day, * Using data up-to-date as of: 2022-05-31. * With the last data available on 2021-12-31 diff --git a/tests/testthat/test-arx_args_list.R b/tests/testthat/test-arx_args_list.R index 9bbff013..226444e3 100644 --- a/tests/testthat/test-arx_args_list.R +++ b/tests/testthat/test-arx_args_list.R @@ -43,6 +43,11 @@ test_that("arx forecaster disambiguates quantiles", { compare_quantile_args(alist, tlist), sort(c(alist, tlist)) ) + expect_snapshot( + error = TRUE, + compare_quantile_args(alist / 10, 1:9 / 10, "grf") + ) + expect_identical(compare_quantile_args(alist, 1:9 / 10, "grf"), 1:9 / 10) alist <- c(.5, alist) expect_identical( # tlist is default, should give alist compare_quantile_args(alist, tlist), diff --git a/tests/testthat/test-grf_quantiles.R b/tests/testthat/test-grf_quantiles.R index e27a0ac6..e2cf90cf 100644 --- a/tests/testthat/test-grf_quantiles.R +++ b/tests/testthat/test-grf_quantiles.R @@ -51,12 +51,30 @@ test_that("quantile_rand_forest handles allows setting the trees and mtry", { expect_identical(pars$`_num_trees`, manual$`_num_trees`) }) -test_that("quantile_rand_forest predicts reasonable quantiles", { +test_that("quantile_rand_forest operates with arx_forecaster", { spec <- rand_forest(mode = "regression") %>% - set_engine("grf_quantiles", quantiles = c(.2, .5, .8)) - expect_silent(out <- fit(spec, formula = y ~ x + z, data = tib)) - # swapping around the probabilities, because somehow this happens in practice, - # but I'm not sure how to reproduce - out$fit$quantiles.orig <- c(0.5, 0.9, 0.1) - expect_no_error(predict(out, tib)) + set_engine("grf_quantiles", quantiles = c(.1, .2, .5, .8, .9)) # non-default + expect_identical(rlang::eval_tidy(spec$eng_args$quantiles), c(.1, .2, .5, .8, .9)) + tib <- as_epi_df(tibble(time_value = 1:25, geo_value = "ca", value = rnorm(25))) + o <- arx_fcast_epi_workflow(tib, "value", trainer = spec) + spec2 <- parsnip::extract_spec_parsnip(o) + expect_identical( + rlang::eval_tidy(spec2$eng_args$quantiles), + rlang::eval_tidy(spec$eng_args$quantiles) + ) + spec <- rand_forest(mode = "regression", "grf_quantiles") + expect_null(rlang::eval_tidy(spec$eng_args)) + o <- arx_fcast_epi_workflow(tib, "value", trainer = spec) + spec2 <- parsnip::extract_spec_parsnip(o) + expect_identical( + rlang::eval_tidy(spec2$eng_args$quantiles), + c(.05, .1, .5, .9, .95) # merged with arx_args default + ) + df <- epidatasets::counts_subset %>% filter(time_value >= "2021-10-01") + + z <- arx_forecaster(df, "cases", "cases", spec2) + expect_identical( + nested_quantiles(z$predictions$.pred_distn[1])[[1]]$quantile_levels, + c(.05, .1, .5, .9, .95) + ) })