1616# ' inverting the existing scaling.
1717# ' @param by A (possibly named) character vector of variables to join by.
1818# '
19- # ' If `NULL`, the default, the function will perform a natural join, using all
20- # ' variables in common across the `epi_df` produced by the `predict()` call
21- # ' and the user-provided dataset.
22- # ' If columns in that `epi_df` and `df` have the same name (and aren't
23- # ' included in `by`), `.df` is added to the one from the user-provided data
24- # ' to disambiguate.
19+ # ' If `NULL`, the default, the function will try to infer a reasonable set of
20+ # ' columns. First, it will try to join by all variables in the training/test
21+ # ' data with roles `"geo_value"`, `"key"`, or `"time_value"` that also appear in
22+ # ' `df`; these roles are automatically set if you are using an `epi_df`, or you
23+ # ' can use, e.g., `update_role`. If no such roles are set, it will try to
24+ # ' perform a natural join, using variables in common between the training/test
25+ # ' data and `df` (excluding `df_pop_col`).
26+ # '
27+ # ' If columns in the training/test data and `df` have the same name (and
28+ # ' aren't included in `by`), a `.df` suffix is added to the one from the
29+ # ' user-provided data to disambiguate.
2530# '
2631# ' To join by different variables on the `epi_df` and `df`, use a named vector.
2732# ' For example, `by = c("geo_value" = "states")` will match `epi_df$geo_value`
2833# ' to `df$states`. To join by multiple variables, use a vector with length > 1.
2934# ' For example, `by = c("geo_value" = "states", "county" = "county")` will match
3035# ' `epi_df$geo_value` to `df$states` and `epi_df$county` to `df$county`.
3136# '
32- # ' See [dplyr::left_join ()] for more details.
37+ # ' See [dplyr::inner_join()] for more details.
3338# ' @param df_pop_col the name of the column in the data frame `df` that
3439# ' contains the population data and will be used for scaling.
3540# ' This should be one column.
@@ -89,13 +94,25 @@ step_population_scaling <-
8994 suffix = " _scaled" ,
9095 skip = FALSE ,
9196 id = rand_id(" population_scaling" )) {
92- arg_is_scalar(role , df_pop_col , rate_rescaling , create_new , suffix , id )
93- arg_is_lgl(create_new , skip )
94- arg_is_chr(df_pop_col , suffix , id )
97+ if (rlang :: dots_n(... ) == 0L ) {
98+ cli_abort(c(
99+ " `...` must not be empty." ,
100+ " >" = " Please provide one or more tidyselect expressions in `...`
101+ specifying the columns to which scaling should be applied." ,
102+ " >" = " If you really want to list `step_population_scaling` in your
103+ recipe but not have it do anything, you can use a tidyselection
104+ that selects zero variables, such as `c()`."
105+ ))
106+ }
107+ arg_is_scalar(role , df_pop_col , rate_rescaling , create_new , suffix , skip , id )
108+ arg_is_chr(role , df_pop_col , suffix , id )
109+ hardhat :: validate_column_names(df , df_pop_col )
95110 arg_is_chr(by , allow_null = TRUE )
111+ arg_is_numeric(rate_rescaling )
96112 if (rate_rescaling < = 0 ) {
97113 cli_abort(" `rate_rescaling` must be a positive number." )
98114 }
115+ arg_is_lgl(create_new , skip )
99116
100117 recipes :: add_step(
101118 recipe ,
@@ -138,6 +155,42 @@ step_population_scaling_new <-
138155
139156# ' @export
140157prep.step_population_scaling <- function (x , training , info = NULL , ... ) {
158+ if (is.null(x $ by )) {
159+ rhs_potential_keys <- setdiff(colnames(x $ df ), x $ df_pop_col )
160+ lhs_potential_keys <- info %> %
161+ filter(role %in% c(" geo_value" , " key" , " time_value" )) %> %
162+ extract2(" variable" ) %> %
163+ unique() # in case of weird var with multiple of above roles
164+ if (length(lhs_potential_keys ) == 0L ) {
165+ # We're working with a recipe and tibble, and *_role hasn't set up any of
166+ # the above roles. Let's say any column could actually act as a key, and
167+ # lean on `intersect` below to make this something reasonable.
168+ lhs_potential_keys <- names(training )
169+ }
170+ suggested_min_keys <- info %> %
171+ filter(role %in% c(" geo_value" , " key" )) %> %
172+ extract2(" variable" ) %> %
173+ unique()
174+ # (0 suggested keys if we weren't given any epikeytime var info.)
175+ x $ by <- intersect(lhs_potential_keys , rhs_potential_keys )
176+ if (length(x $ by ) == 0L ) {
177+ cli_stop(c(
178+ " Couldn't guess a default for `by`" ,
179+ " >" = " Please rename columns in your population data to match those in your training data,
180+ or manually specify `by =` in `step_population_scaling()`."
181+ ), class = " epipredict__step_population_scaling__default_by_no_intersection" )
182+ }
183+ if (! all(suggested_min_keys %in% x $ by )) {
184+ cli_warn(c(
185+ " {setdiff(suggested_min_keys, x$by)} {?was an/were} epikey column{?s} in the training data,
186+ but {?wasn't/weren't} found in the population `df`." ,
187+ " i" = " Defaulting to join by {x$by}." ,
188+ " >" = " Double-check whether column names on the population `df` match those for your training data." ,
189+ " >" = " Consider using population data with breakdowns by {suggested_min_keys}." ,
190+ " >" = " Manually specify `by =` to silence."
191+ ), class = " epipredict__step_population_scaling__default_by_missing_suggested_keys" )
192+ }
193+ }
141194 step_population_scaling_new(
142195 terms = x $ terms ,
143196 role = x $ role ,
@@ -156,10 +209,14 @@ prep.step_population_scaling <- function(x, training, info = NULL, ...) {
156209
157210# ' @export
158211bake.step_population_scaling <- function (object , new_data , ... ) {
159- object $ by <- object $ by %|| % intersect(
160- epi_keys_only(new_data ),
161- colnames(select(object $ df , ! object $ df_pop_col ))
162- )
212+ if (is.null(object $ by )) {
213+ cli :: cli_abort(c(
214+ " `by` was not set and no default was filled in" ,
215+ " >" = " If this was a fit recipe generated from an older version
216+ of epipredict that you loaded in from a file,
217+ please regenerate with the current version of epipredict."
218+ ))
219+ }
163220 joinby <- list (x = names(object $ by ) %|| % object $ by , y = object $ by )
164221 hardhat :: validate_column_names(new_data , joinby $ x )
165222 hardhat :: validate_column_names(object $ df , joinby $ y )
@@ -177,7 +234,10 @@ bake.step_population_scaling <- function(object, new_data, ...) {
177234 suffix <- ifelse(object $ create_new , object $ suffix , " " )
178235 col_to_remove <- setdiff(colnames(object $ df ), colnames(new_data ))
179236
180- left_join(new_data , object $ df , by = object $ by , suffix = c(" " , " .df" )) %> %
237+ inner_join(new_data , object $ df ,
238+ by = object $ by , relationship = " many-to-one" , unmatched = c(" error" , " drop" ),
239+ suffix = c(" " , " .df" )
240+ ) %> %
181241 mutate(
182242 across(
183243 all_of(object $ columns ),
0 commit comments