Skip to content

Commit

Permalink
Issue 92: add baseline timeseries forecaster for target (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelBrand1 authored Oct 31, 2024
1 parent 2949f9f commit 24872e4
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 118 deletions.
114 changes: 0 additions & 114 deletions nssp_demo/forecast_non_target_visits.R

This file was deleted.

8 changes: 4 additions & 4 deletions nssp_demo/forecast_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
from generate_predictive import generate_and_save_predictions # noqa


def forecast_denominator(
def baseline_forecasts(
model_run_dir: Path, n_forecast_days: int, n_samples: int
) -> None:
subprocess.run(
[
"Rscript",
"nssp_demo/forecast_non_target_visits.R",
"nssp_demo/timeseries_forecasts.R",
"--model-run-dir",
f"{model_run_dir}",
"--n-forecast-days",
Expand Down Expand Up @@ -187,9 +187,9 @@ def main(
logger.info("Performing posterior prediction / forecasting...")
generate_and_save_predictions(model_run_dir, n_forecast_days)

logger.info("Performing non-target pathogen forecasting...")
logger.info("Performing baseline forecasting and non-target pathogen forecasting...")
n_denominator_samples = n_samples * n_chains
forecast_denominator(model_run_dir, n_forecast_days, n_denominator_samples)
baseline_forecasts(model_run_dir, n_forecast_days, n_denominator_samples)
logger.info("Forecasting complete.")

logger.info("Postprocessing forecast...")
Expand Down
148 changes: 148 additions & 0 deletions nssp_demo/timeseries_forecasts.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
script_packages <- c(
"dplyr",
"tidyr",
"tibble",
"readr",
"stringr",
"fs",
"fable",
"jsonlite",
"argparser",
"arrow"
)

## load in packages without messages
purrr::walk(script_packages, \(pkg) {
suppressPackageStartupMessages(
library(pkg, character.only = TRUE)
)
})

#' Fit and Forecast Time Series Data
#'
#' Fits a combination ensemble of an ETS model (trend methods "N", "M", "A")
#' and an ARIMA model, both on the log of the target column, to the training
#' rows of `data`, then simulates future sample paths from the fitted ensemble.
#'
#' @param data A tsibble (it is passed directly to `model()`). It must contain
#' a `date` column, a `data_type` column (only rows where
#' `data_type == "train"` are used for fitting), and the column named by
#' `target_col`.
#' @param n_forecast_days An integer specifying the number of days to forecast.
#' Default is 28.
#' @param n_samples An integer specifying the number of forecast samples to
#' generate. Default is 2000.
#' @param target_col A string specifying the name of the target column in the
#' data. Default is "ed_visits".
#' @param output_col A string specifying the name of the output column for the
#' forecasted values. Default is "other_ed_visits".
#'
#' @return A tibble containing the forecast samples with columns `date`,
#' `.draw` (integer draw number), and the column named by `output_col`.
fit_and_forecast <- function(data,
                             n_forecast_days = 28,
                             n_samples = 2000,
                             target_col = "ed_visits",
                             output_col = "other_ed_visits") {
  stopifnot(
    is.character(target_col), length(target_col) == 1L,
    is.character(output_col), length(output_col) == 1L
  )

  # Fail early with a readable message instead of an opaque error from deep
  # inside the modelling or tidy-eval machinery.
  missing_cols <- setdiff(c("date", "data_type", target_col), names(data))
  if (length(missing_cols) > 0) {
    stop(
      "`data` is missing required column(s): ",
      paste(missing_cols, collapse = ", "),
      call. = FALSE
    )
  }
  if (!("train" %in% data[["data_type"]])) {
    stop(
      "`data` has no rows with data_type == \"train\" to fit on.",
      call. = FALSE
    )
  }

  forecast_horizon <- glue::glue("{n_forecast_days} days")
  target_sym <- rlang::sym(target_col)
  output_sym <- rlang::sym(output_col)

  fit <-
    data |>
    filter(data_type == "train") |>
    model(
      comb_model = combination_ensemble(
        ETS(log(!!target_sym) ~ trend(method = c("N", "M", "A"))),
        ARIMA(log(!!target_sym))
      )
    )

  # `generate()` yields one simulated path per `.rep`; expose the simulated
  # values under the caller-chosen column name and the path id as `.draw`.
  forecast_samples <- fit |>
    generate(h = forecast_horizon, times = n_samples) |>
    as_tibble() |>
    mutate("{output_col}" := .sim, .draw = as.integer(.rep)) |> # nolint
    select(date, .draw, !!output_sym)

  forecast_samples
}

#' Forecast non-target and baseline target ED visits for one model run
#'
#' Reads `data.csv` from `model_run_dir`, relabels the target disease rows as
#' "Disease", derives "other" visits as `Total - Disease`, fits the timeseries
#' ensemble to each series via `fit_and_forecast()`, and writes the samples to
#' `other_ed_visits_forecast.parquet` and
#' `baseline_ed_visits_forecast.parquet` in `model_run_dir`.
#'
#' @param model_run_dir Path to the model run directory.
#' @param n_forecast_days Number of days to forecast. Default is 28.
#' @param n_samples Number of forecast samples to generate. Default is 2000.
#' @param disease_name Value of the `disease` column that identifies the
#' target disease. Defaults to the script-level `disease_name_nssp`, so
#' existing three-argument calls keep working.
#'
#' @return Called for its side effects (the two parquet files).
main <- function(model_run_dir,
                 n_forecast_days = 28,
                 n_samples = 2000,
                 disease_name = disease_name_nssp) {
  # An NA disease name would silently turn every `disease` label into NA in
  # the `if_else()` below, so reject it up front.
  if (length(disease_name) != 1L || is.na(disease_name)) {
    stop(
      "`disease_name` must be a single non-missing string; ",
      "could not determine the target disease.",
      call. = FALSE
    )
  }

  # to do: do this with json data that has dates
  data_path <- path(model_run_dir, "data", ext = "csv")

  wide_visits <- read_csv(
    data_path,
    col_types = cols(
      disease = col_character(),
      data_type = col_character(),
      ed_visits = col_double(),
      date = col_date()
    )
  ) |>
    # `!!` injects the value so it can never be shadowed by a data column.
    mutate(disease = if_else(
      disease == !!disease_name,
      "Disease", disease
    )) |>
    pivot_wider(names_from = disease, values_from = ed_visits)

  missing_cols <- setdiff(c("Total", "Disease"), names(wide_visits))
  if (length(missing_cols) > 0) {
    stop(
      "After pivoting, ", data_path, " has no column(s): ",
      paste(missing_cols, collapse = ", "),
      " (target disease: ", disease_name, ").",
      call. = FALSE
    )
  }

  target_and_other_data <- wide_visits |>
    mutate(Other = Total - Disease) |>
    select(date,
      ed_visits_target = Disease, ed_visits_other = Other,
      data_type
    ) |>
    as_tsibble(index = date)

  forecast_other <- fit_and_forecast(target_and_other_data, n_forecast_days,
    n_samples,
    target_col = "ed_visits_other", output_col = "other_ed_visits"
  )
  forecast_baseline <- fit_and_forecast(target_and_other_data, n_forecast_days,
    n_samples,
    target_col = "ed_visits_target",
    output_col = "baseline_ed_visits"
  )

  save_path_other <- path(model_run_dir, "other_ed_visits_forecast",
    ext = "parquet"
  )
  save_path_baseline <- path(model_run_dir, "baseline_ed_visits_forecast",
    ext = "parquet"
  )
  write_parquet(forecast_other, save_path_other)
  write_parquet(forecast_baseline, save_path_baseline)
}


# Command-line interface ----
p <- arg_parser(
  "Forecast other (non-target-disease) ED visits for a given location."
) |>
  add_argument(
    "--model-run-dir",
    help = "Directory containing the model data and output."
  ) |>
  add_argument(
    "--n-forecast-days",
    help = "Number of days to forecast.",
    default = 28L
  ) |>
  add_argument(
    "--n-samples",
    help = "Number of samples to generate.",
    default = 2000L
  )

argv <- parse_args(p)

# `--model-run-dir` has no default, so guard against it being omitted rather
# than letting a missing path propagate into the file reads.
if (length(argv$model_run_dir) != 1L || is.na(argv$model_run_dir)) {
  stop("--model-run-dir must be supplied.", call. = FALSE)
}

model_run_dir <- path(argv$model_run_dir)
n_forecast_days <- argv$n_forecast_days
n_samples <- argv$n_samples

# Maps the disease prefix of the parent directory name to the label used in
# the `disease` column of the data.
disease_name_nssp_map <- c(
  "covid-19" = "COVID-19",
  "influenza" = "Influenza"
)

base_dir <- path_dir(model_run_dir)

# The disease name is everything before "_r_" in the parent directory name.
disease_name_raw <- base_dir |>
  path_file() |>
  str_extract("^.+(?=_r_)")

if (is.na(disease_name_raw) ||
  !(disease_name_raw %in% names(disease_name_nssp_map))) {
  stop(
    "Could not determine a supported disease from directory name '",
    path_file(base_dir), "'. Expected a name starting with one of: ",
    paste(names(disease_name_nssp_map), collapse = ", "),
    " followed by '_r_'.",
    call. = FALSE
  )
}

disease_name_nssp <- unname(disease_name_nssp_map[disease_name_raw])

main(model_run_dir, n_forecast_days, n_samples)

0 comments on commit 24872e4

Please sign in to comment.