diff --git a/DESCRIPTION b/DESCRIPTION
index 456cdb54..0f27f9cc 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -72,6 +72,7 @@ Remotes:
cmu-delphi/epidatasets,
cmu-delphi/epidatr,
cmu-delphi/epiprocess,
+ cmu-delphi/epidatasets,
dajmcdon/smoothqr
Config/testthat/edition: 3
Encoding: UTF-8
diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R
index c7aebef4..b386fe45 100644
--- a/R/arx_forecaster.R
+++ b/R/arx_forecaster.R
@@ -186,14 +186,26 @@ arx_fcast_epi_workflow <- function(
# --- postprocessor
f <- frosting() %>% layer_predict() # %>% layer_naomit()
- if (inherits(trainer, "quantile_reg")) {
+ is_quantile_reg <- inherits(trainer, "quantile_reg") |
+ (inherits(trainer, "rand_forest") & trainer$engine == "grf_quantiles")
+ if (is_quantile_reg) {
# add all quantile_level to the forecaster and update postprocessor
- quantile_levels <- sort(compare_quantile_args(
- args_list$quantile_levels,
- rlang::eval_tidy(trainer$args$quantile_levels)
- ))
+ if (inherits(trainer, "quantile_reg")) {
+ quantile_levels <- sort(compare_quantile_args(
+ args_list$quantile_levels,
+ rlang::eval_tidy(trainer$args$quantile_levels),
+ "qr"
+ ))
+ trainer$args$quantile_levels <- rlang::enquo(quantile_levels)
+ } else {
+ quantile_levels <- sort(compare_quantile_args(
+ args_list$quantile_levels,
+ rlang::eval_tidy(trainer$eng_args$quantiles) %||% c(.1, .5, .9),
+ "grf"
+ ))
+ trainer$eng_args$quantiles <- rlang::enquo(quantile_levels)
+ }
args_list$quantile_levels <- quantile_levels
- trainer$args$quantile_levels <- rlang::enquo(quantile_levels)
f <- f %>%
layer_quantile_distn(quantile_levels = quantile_levels) %>%
layer_point_from_distn()
@@ -345,9 +357,13 @@ print.arx_fcast <- function(x, ...) {
NextMethod(name = name, ...)
}
-compare_quantile_args <- function(alist, tlist) {
+compare_quantile_args <- function(alist, tlist, train_method = c("qr", "grf")) {
+ train_method <- rlang::arg_match(train_method)
default_alist <- eval(formals(arx_args_list)$quantile_levels)
- default_tlist <- eval(formals(quantile_reg)$quantile_levels)
+ default_tlist <- switch(train_method,
+ "qr" = eval(formals(quantile_reg)$quantile_levels),
+ "grf" = c(.1, .5, .9)
+ )
if (setequal(alist, default_alist)) {
if (setequal(tlist, default_tlist)) {
return(sort(unique(union(alist, tlist))))
diff --git a/R/canned-epipred.R b/R/canned-epipred.R
index 1e088426..8a5baa58 100644
--- a/R/canned-epipred.R
+++ b/R/canned-epipred.R
@@ -77,8 +77,9 @@ print.canned_epipred <- function(x, name, ...) {
fn_meta <- function() {
cli::cli_ul()
cli::cli_li("Geography: {.field {x$metadata$training$geo_type}},")
- if (!is.null(x$metadata$training$other_keys)) {
- cli::cli_li("Other keys: {.field {x$metadata$training$other_keys}},")
+ other_keys <- x$metadata$training$other_keys
+ if (!is.null(other_keys) && length(other_keys) > 0L) {
+ cli::cli_li("Other keys: {.field {other_keys}},")
}
cli::cli_li("Time type: {.field {x$metadata$training$time_type}},")
cli::cli_li("Using data up-to-date as of: {.field {format(x$metadata$training$as_of)}}.")
diff --git a/man/step_adjust_latency.Rd b/man/step_adjust_latency.Rd
index 1a677042..0078de10 100644
--- a/man/step_adjust_latency.Rd
+++ b/man/step_adjust_latency.Rd
@@ -267,8 +267,8 @@ while this will not:
\if{html}{\out{
}}\preformatted{toy_recipe <- epi_recipe(toy_df) \%>\%
step_epi_lag(a, lag=0) \%>\%
step_adjust_latency(a, method = "extend_lags")
-#> Warning: If `method` is "extend_lags" or "locf", then the previous
-#> `step_epi_lag`s won't work with modified data.
+#> Warning: If `method` is "extend_lags" or "locf", then the previous `step_epi_lag`s won't
+#> work with modified data.
}\if{html}{\out{
}}
If you create columns that you then apply lags to (such as
diff --git a/tests/testthat/_snaps/arx_args_list.md b/tests/testthat/_snaps/arx_args_list.md
index 959a5e25..2579c5f0 100644
--- a/tests/testthat/_snaps/arx_args_list.md
+++ b/tests/testthat/_snaps/arx_args_list.md
@@ -124,6 +124,15 @@
# arx forecaster disambiguates quantiles
+ Code
+ compare_quantile_args(alist / 10, 1:9 / 10, "grf")
+ Condition
+ Error in `compare_quantile_args()`:
+ ! You have specified different, non-default, quantiles in the trainier and `arx_args` options.
+ i Please only specify quantiles in one location.
+
+---
+
Code
compare_quantile_args(alist, tlist)
Condition
diff --git a/tests/testthat/_snaps/snapshots.md b/tests/testthat/_snaps/snapshots.md
index a03a8dd4..f3e7e573 100644
--- a/tests/testthat/_snaps/snapshots.md
+++ b/tests/testthat/_snaps/snapshots.md
@@ -1093,7 +1093,6 @@
Training data was an with:
* Geography: state,
- * Other keys: ,
* Time type: day,
* Using data up-to-date as of: 2022-05-31.
* With the last data available on 2021-12-31
@@ -1117,7 +1116,6 @@
Training data was an with:
* Geography: state,
- * Other keys: ,
* Time type: day,
* Using data up-to-date as of: 2022-05-31.
* With the last data available on 2021-12-31
@@ -1142,7 +1140,6 @@
Training data was an with:
* Geography: state,
- * Other keys: ,
* Time type: day,
* Using data up-to-date as of: 2022-05-31.
* With the last data available on 2021-12-31
diff --git a/tests/testthat/test-arx_args_list.R b/tests/testthat/test-arx_args_list.R
index 9bbff013..226444e3 100644
--- a/tests/testthat/test-arx_args_list.R
+++ b/tests/testthat/test-arx_args_list.R
@@ -43,6 +43,11 @@ test_that("arx forecaster disambiguates quantiles", {
compare_quantile_args(alist, tlist),
sort(c(alist, tlist))
)
+ expect_snapshot(
+ error = TRUE,
+ compare_quantile_args(alist / 10, 1:9 / 10, "grf")
+ )
+ expect_identical(compare_quantile_args(alist, 1:9 / 10, "grf"), 1:9 / 10)
alist <- c(.5, alist)
expect_identical( # tlist is default, should give alist
compare_quantile_args(alist, tlist),
diff --git a/tests/testthat/test-grf_quantiles.R b/tests/testthat/test-grf_quantiles.R
index e27a0ac6..e2cf90cf 100644
--- a/tests/testthat/test-grf_quantiles.R
+++ b/tests/testthat/test-grf_quantiles.R
@@ -51,12 +51,30 @@ test_that("quantile_rand_forest handles allows setting the trees and mtry", {
expect_identical(pars$`_num_trees`, manual$`_num_trees`)
})
-test_that("quantile_rand_forest predicts reasonable quantiles", {
+test_that("quantile_rand_forest operates with arx_forecaster", {
spec <- rand_forest(mode = "regression") %>%
- set_engine("grf_quantiles", quantiles = c(.2, .5, .8))
- expect_silent(out <- fit(spec, formula = y ~ x + z, data = tib))
- # swapping around the probabilities, because somehow this happens in practice,
- # but I'm not sure how to reproduce
- out$fit$quantiles.orig <- c(0.5, 0.9, 0.1)
- expect_no_error(predict(out, tib))
+ set_engine("grf_quantiles", quantiles = c(.1, .2, .5, .8, .9)) # non-default
+ expect_identical(rlang::eval_tidy(spec$eng_args$quantiles), c(.1, .2, .5, .8, .9))
+ tib <- as_epi_df(tibble(time_value = 1:25, geo_value = "ca", value = rnorm(25)))
+ o <- arx_fcast_epi_workflow(tib, "value", trainer = spec)
+ spec2 <- parsnip::extract_spec_parsnip(o)
+ expect_identical(
+ rlang::eval_tidy(spec2$eng_args$quantiles),
+ rlang::eval_tidy(spec$eng_args$quantiles)
+ )
+ spec <- rand_forest(mode = "regression", "grf_quantiles")
+ expect_null(rlang::eval_tidy(spec$eng_args))
+ o <- arx_fcast_epi_workflow(tib, "value", trainer = spec)
+ spec2 <- parsnip::extract_spec_parsnip(o)
+ expect_identical(
+ rlang::eval_tidy(spec2$eng_args$quantiles),
+ c(.05, .1, .5, .9, .95) # merged with arx_args default
+ )
+ df <- epidatasets::counts_subset %>% filter(time_value >= "2021-10-01")
+
+ z <- arx_forecaster(df, "cases", "cases", spec2)
+ expect_identical(
+ nested_quantiles(z$predictions$.pred_distn[1])[[1]]$quantile_levels,
+ c(.05, .1, .5, .9, .95)
+ )
})