|
| 1 | +prob_arx <- function(x, y, geo_value, time_value, lags = c(0, 7, 14), |
| 2 | + ahead = 7, min_train_window = 20, lower_level = 0.05, |
| 3 | + upper_level = 0.95, symmetrize = TRUE, nonneg = TRUE) { |
| 4 | + # Return NA if insufficient training data |
| 5 | + if (length(y) < min_train_window + max(lags) + ahead) { |
| 6 | + return(data.frame(point = NA, lower = NA, upper = NA)) |
| 7 | + } |
| 8 | + |
| 9 | + # Useful transformations |
| 10 | + if (!missing(x)) x <- data.frame(x, y) |
| 11 | + else x <- data.frame(y) |
| 12 | + if (!is.list(lags)) lags <- list(lags) |
| 13 | + lags = rep(lags, length.out = ncol(x)) |
| 14 | + |
| 15 | + # Build features and response for the AR model, and then fit it |
| 16 | + dat <- do.call( |
| 17 | + data.frame, |
| 18 | + unlist( # Below we loop through and build the lagged features |
| 19 | + purrr::map(1:ncol(x), function(i) { |
| 20 | + purrr::map(lags[[i]], function(lag) dplyr::lag(x[,i], n = lag)) |
| 21 | + }), |
| 22 | + recursive = FALSE |
| 23 | + ) |
| 24 | + ) |
| 25 | + dat$y <- dplyr::lead(y, n = ahead) |
| 26 | + obj <- lm(y ~ ., data = dat) |
| 27 | + |
| 28 | + # Use LOCF to fill NAs in the latest feature values, make a prediction |
| 29 | + data.table::setnafill(dat, type = "locf") |
| 30 | + dat <- cbind(dat, data.frame(geo_value, time_value)) |
| 31 | + point <- predict(obj, newdata = dat %>% |
| 32 | + dplyr::group_by(geo_value) %>% |
| 33 | + dplyr::filter(time_value == max(time_value))) |
| 34 | + |
| 35 | + # Compute a band |
| 36 | + r <- residuals(obj) |
| 37 | + s <- ifelse(symmetrize, -1, NA) # Should the residuals be symmetrized? |
| 38 | + q <- quantile(c(r, s * r), probs = c(lower_level, upper_level), na.rm = TRUE) |
| 39 | + lower <- point + q[1] |
| 40 | + upper <- point + q[2] |
| 41 | + |
| 42 | + # Clip at zero if we need to, then return |
| 43 | + if (nonneg) { |
| 44 | + point = pmax(point, 0) |
| 45 | + lower = pmax(lower, 0) |
| 46 | + upper = pmax(upper, 0) |
| 47 | + } |
| 48 | + return(data.frame(geo_value = unique(geo_value), # Must include geo value! |
| 49 | + point = point, lower = lower, upper = upper)) |
| 50 | +} |
0 commit comments