Skip to content

Commit

Permalink
add cdc flat forecast
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelBrand1 committed Oct 31, 2024
1 parent 559b4b5 commit 6b3f264
Showing 1 changed file with 85 additions and 4 deletions.
89 changes: 85 additions & 4 deletions nssp_demo/timeseries_forecasts.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,38 @@ script_packages <- c(
"fable",
"jsonlite",
"argparser",
"arrow"
"arrow",
"pak",
"glue"
)

## Packages that are installed from GitHub (via pak) rather than from CRAN
script_pak_packages <- c("epipredict", "epiprocess")

## Attach each CRAN package in turn, silencing its startup messages
invisible(lapply(script_packages, function(pkg_name) {
  suppressPackageStartupMessages(
    library(pkg_name, character.only = TRUE)
  )
}))

## Load packages from the cmu-delphi repo, installing them first if required
purrr::walk(script_pak_packages, \(pkg) {
  # `requireNamespace()` checks for this one package only; scanning the whole
  # library with `installed.packages()` is slow and discouraged for this use.
  if (!requireNamespace(pkg, quietly = TRUE)) {
    # Namespace-qualify `glue` so this does not depend on attach order.
    suppressMessages(pak::pkg_install(glue::glue("cmu-delphi/{pkg}@main")))
  }
  # Attach once, whether the package was already present or just installed.
  suppressPackageStartupMessages(library(pkg, character.only = TRUE))
})

#' Fit and Forecast Time Series Data
#'
#' This function fits a combination ensemble model to the training data and
Expand Down Expand Up @@ -46,6 +68,7 @@ fit_and_forecast <- function(data,
output_sym <- rlang::sym(output_col)
fit <-
data |>
as_tsibble(index = date) |>
filter(data_type == "train") |>
model(
comb_model = combination_ensemble(
Expand All @@ -63,6 +86,48 @@ fit_and_forecast <- function(data,
forecast_samples
}

#' Generate CDC Flat Forecast
#'
#' This function fits the `epipredict` CDC flat baseline forecaster to the
#' training rows of `data` and returns its quantile forecasts in long format
#' (one row per target date and quantile level), suitable for use with
#' `scoringutils`.
#'
#' @param data A data frame containing the input data. It must have a `date`
#' column, a `data_type` column (only rows with `data_type == "train"` are
#' used for fitting), and the column named by `target_col`.
#' @param target_col A string specifying the column name of the target variable
#' in the data. Default is "ed_visits_target".
#' @param output_col A string specifying the column name for the output variable
#' in the forecast. Default is "cdc_flat_ed_visits".
#' @param ... Additional arguments passed to the
#' `epipredict::cdc_baseline_args_list` function.
#' @return A data frame with columns `date` (the forecast target date),
#' `quantile_level`, and a column named by `output_col` holding the forecast
#' value at that quantile level.
cdc_flat_forecast <- function(data,
                              target_col = "ed_visits_target",
                              output_col = "cdc_flat_ed_visits",
                              ...) {
  opts <- cdc_baseline_args_list(...)
  # coerce the training data to epiprocess::epi_df format by adding the
  # `geo_value` and `time_value` columns
  epi_data <- data |>
    filter(data_type == "train") |>
    mutate(geo_value = "us", time_value = date) |>
    as_epi_df()
  # fit the model
  cdc_flat_fit <- cdc_baseline_forecaster(epi_data, target_col, opts)
  # generate forecast: `pivot_quantiles_longer()` expands `.pred_distn` into
  # `values` and `quantile_levels`. Take the per-quantile `values`; `.pred` is
  # the point forecast and would be identical at every quantile level.
  forecast_long <- cdc_flat_fit$predictions |>
    pivot_quantiles_longer(.pred_distn) |>
    mutate("{output_col}" := values) |> # nolint
    select(
      date = target_date,
      quantile_level = quantile_levels,
      all_of(output_col)
    )

  forecast_long
}

main <- function(model_run_dir, n_forecast_days = 28, n_samples = 2000) {
# to do: do this with json data that has dates
data_path <- path(model_run_dir, "data", ext = "csv")
Expand All @@ -85,27 +150,43 @@ main <- function(model_run_dir, n_forecast_days = 28, n_samples = 2000) {
select(date,
ed_visits_target = Disease, ed_visits_other = Other,
data_type
) |>
as_tsibble(index = date)

)
## Time series forecasting
## Fit and forecast other (non-target-disease) ED visits using a combination
## ensemble model
forecast_other <- fit_and_forecast(target_and_other_data, n_forecast_days,
n_samples,
target_col = "ed_visits_other", output_col = "other_ed_visits"
)
## Fit and forecast the baseline number of target-disease ED visits using a
## combination ensemble model
forecast_baseline <- fit_and_forecast(target_and_other_data, n_forecast_days,
n_samples,
target_col = "ed_visits_target",
output_col = "baseline_ed_visits"
)
## Generate CDC flat forecast for the target disease number of ED visits
forecast_cdc_flat <- cdc_flat_forecast(target_and_other_data,
target_col = "ed_visits_target",
output_col = "cdc_flat_ed_visits",
data_frequency = "1 day",
aheads = 1:n_forecast_days
)

## Save the forecasted values to parquet files
save_path_other <- path(model_run_dir, "other_ed_visits_forecast",
ext = "parquet"
)
save_path_baseline <- path(model_run_dir, "baseline_ed_visits_forecast",
ext = "parquet"
)
save_path_cdc_flat <- path(model_run_dir, "cdc_flat_ed_visits_forecast",
ext = "parquet"
)

write_parquet(forecast_other, save_path_other)
write_parquet(forecast_baseline, save_path_baseline)
write_parquet(forecast_cdc_flat, save_path_cdc_flat)
}


Expand Down

0 comments on commit 6b3f264

Please sign in to comment.