From 5c0459158febfcfcbf668113e760fdbb0011c1f8 Mon Sep 17 00:00:00 2001 From: Damon Bayer Date: Fri, 1 Nov 2024 18:37:24 -0500 Subject: [PATCH] Revised forecast plots (#99) revised figures --- nssp_demo/postprocess_state_forecast.R | 94 +++++++++++++++----------- 1 file changed, 56 insertions(+), 38 deletions(-) diff --git a/nssp_demo/postprocess_state_forecast.R b/nssp_demo/postprocess_state_forecast.R index 40a03e10..ac62ec1a 100644 --- a/nssp_demo/postprocess_state_forecast.R +++ b/nssp_demo/postprocess_state_forecast.R @@ -12,7 +12,8 @@ script_packages <- c( "arrow", "tidyr", "readr", - "here" + "here", + "forcats" ) ## load in packages without messages @@ -24,7 +25,7 @@ purrr::walk(script_packages, \(pkg) { # To be replaced with reading tidy data from forecasttools -read_pyrenew_samples <- function(inference_data_path, +read_pyrenew_samples <- function(inference_train_data_path, filter_bad_chains = TRUE, good_chain_tol = 2) { arviz_split <- function(x) { @@ -34,7 +35,7 @@ read_pyrenew_samples <- function(inference_data_path, } pyrenew_samples <- - read_csv(inference_data_path, + read_csv(inference_train_data_path, show_col_types = FALSE ) |> rename_with(\(varname) str_remove_all(varname, "\\(|\\)|\\'|(, \\d+)")) |> @@ -76,9 +77,9 @@ read_pyrenew_samples <- function(inference_data_path, } make_one_forecast_fig <- function(target_disease, - dat, + combined_dat, last_training_date, - last_data_date, + data_vintage_date, posterior_predictive_ci, state_abb) { y_scale <- if (str_starts(target_disease, "prop")) { @@ -99,11 +100,28 @@ make_one_forecast_fig <- function(target_disease, geom_lineribbon( data = posterior_predictive_ci |> filter(disease == target_disease), mapping = aes(ymin = .lower, ymax = .upper), - color = "#08519c", key_glyph = draw_key_rect, step = "mid" + color = "#08519c", + key_glyph = draw_key_rect, + step = "mid" + ) + + scale_fill_brewer( + name = "Credible Interval Width", + labels = ~ percent(as.numeric(.)) ) + geom_point( - mapping = aes(shape = data_type), - data = dat |> filter(disease == target_disease) + mapping = aes(color = data_type), size = 1.5, + data = combined_dat |> + filter( + disease == target_disease, + date <= max(posterior_predictive_ci$date) + ) |> + mutate(data_type = fct_rev(data_type)) |> + arrange(desc(data_type)) + ) + + scale_color_manual( + name = "Data Type", + values = c("olivedrab1", "deeppink"), + labels = str_to_title ) + geom_vline(xintercept = last_training_date, linetype = "dashed") + annotate( @@ -121,14 +139,10 @@ make_one_forecast_fig <- function(target_disease, hjust = "left", vjust = "bottom", ) + - ggtitle(title, subtitle = glue("as of {last_data_date}")) + + ggtitle(title, subtitle = glue("as of {data_vintage_date}")) + y_scale + scale_x_date("Date") + - scale_shape_discrete("Data Type", labels = str_to_title) + - scale_fill_brewer( - name = "Credible Interval Width", - labels = ~ percent(as.numeric(.)) - ) + + # scale_shape_discrete("Data Type", labels = str_to_title) + theme(legend.position = "bottom") } @@ -141,8 +155,9 @@ postprocess_state_forecast <- function(model_run_dir, pluck(1) |> tail(1) - data_path <- path(model_run_dir, "data", ext = "csv") - inference_data_path <- path(model_run_dir, "inference_data", + train_data_path <- path(model_run_dir, "data", ext = "csv") + eval_data_path <- path(model_run_dir, "eval_data", ext = "tsv") + inference_train_data_path <- path(model_run_dir, "inference_data", ext = "csv" ) other_ed_visits_path <- path( @@ -151,15 +166,21 @@ postprocess_state_forecast <- function(model_run_dir, ext = "parquet" ) - dat <- read_csv( - data_path, - col_types = cols( - disease = col_character(), - data_type = col_character(), - ed_visits = col_double(), - date = col_date() - ) - ) |> + train_dat <- read_csv(train_data_path, show_col_types = FALSE) + + data_vintage_date <- max(train_dat$date) + 1 + # this should be stored as metadata somewhere else, instead of being + # computed like this + + eval_dat <- read_tsv(eval_data_path, show_col_types = FALSE) |> + mutate(data_type = "eval") + + combined_dat <- + bind_rows( + train_dat |> + filter(data_type == "train"), + eval_dat + ) |> mutate( disease = if_else( disease == disease_name_nssp, @@ -180,16 +201,13 @@ postprocess_state_forecast <- function(model_run_dir, values_to = ".value" ) - last_training_date <- dat |> - filter(data_type == "train") |> - pull(date) |> - max() - last_data_date <- dat |> + last_training_date <- combined_dat |> + filter(data_type == "train") |> pull(date) |> max() - pyrenew_samples <- read_pyrenew_samples(inference_data_path, + pyrenew_samples <- read_pyrenew_samples(inference_train_data_path, filter_bad_chains = filter_bad_chains, good_chain_tol = good_chain_tol ) @@ -198,11 +216,11 @@ postprocess_state_forecast <- function(model_run_dir, read_parquet(other_ed_visits_path) |> rename(Other = other_ed_visits) - other_ed_visits_samples <- bind_rows( - dat |> + combined_dat |> filter( + data_type == "eval", disease == "Other", date <= last_training_date ) |> @@ -217,7 +235,7 @@ postprocess_state_forecast <- function(model_run_dir, pivot_wider(names_from = .variable, values_from = .value) |> rename(Disease = observed_hospital_admissions) |> ungroup() |> - mutate(date = min(dat$date) + time) |> + mutate(date = min(combined_dat$date) + time) |> left_join(other_ed_visits_samples, by = c(".draw", "date") ) |> @@ -250,12 +268,12 @@ postprocess_state_forecast <- function(model_run_dir, all_forecast_plots <- map( - set_names(unique(dat$disease)), + set_names(unique(combined_dat$disease)), ~ make_one_forecast_fig( .x, - dat, + combined_dat, last_training_date, - last_data_date, + data_vintage_date, posterior_predictive_ci, state_abb ) @@ -285,7 +303,7 @@ p <- arg_parser("Generate forecast figures") |> ) |> add_argument( "--filter-bad-chains", - help = "Filter out bad chains from the samples? Default TRUE.", + help = "Filter out bad chains from the samples? Default FALSE", flag = TRUE ) |> add_argument(