Skip to content

Commit

Permalink
Revised forecast plots (#99)
Browse files Browse the repository at this point in the history
revised figures
  • Loading branch information
damonbayer authored Nov 1, 2024
1 parent f4c4091 commit 5c04591
Showing 1 changed file with 56 additions and 38 deletions.
94 changes: 56 additions & 38 deletions nssp_demo/postprocess_state_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ script_packages <- c(
"arrow",
"tidyr",
"readr",
"here"
"here",
"forcats"
)

## load in packages without messages
Expand All @@ -24,7 +25,7 @@ purrr::walk(script_packages, \(pkg) {


# To be replaced with reading tidy data from forecasttools
read_pyrenew_samples <- function(inference_data_path,
read_pyrenew_samples <- function(inference_train_data_path,
filter_bad_chains = TRUE,
good_chain_tol = 2) {
arviz_split <- function(x) {
Expand All @@ -34,7 +35,7 @@ read_pyrenew_samples <- function(inference_data_path,
}

pyrenew_samples <-
read_csv(inference_data_path,
read_csv(inference_train_data_path,
show_col_types = FALSE
) |>
rename_with(\(varname) str_remove_all(varname, "\\(|\\)|\\'|(, \\d+)")) |>
Expand Down Expand Up @@ -76,9 +77,9 @@ read_pyrenew_samples <- function(inference_data_path,
}

make_one_forecast_fig <- function(target_disease,
dat,
combined_dat,
last_training_date,
last_data_date,
data_vintage_date,
posterior_predictive_ci,
state_abb) {
y_scale <- if (str_starts(target_disease, "prop")) {
Expand All @@ -99,11 +100,28 @@ make_one_forecast_fig <- function(target_disease,
geom_lineribbon(
data = posterior_predictive_ci |> filter(disease == target_disease),
mapping = aes(ymin = .lower, ymax = .upper),
color = "#08519c", key_glyph = draw_key_rect, step = "mid"
color = "#08519c",
key_glyph = draw_key_rect,
step = "mid"
) +
scale_fill_brewer(
name = "Credible Interval Width",
labels = ~ percent(as.numeric(.))
) +
geom_point(
mapping = aes(shape = data_type),
data = dat |> filter(disease == target_disease)
mapping = aes(color = data_type), size = 1.5,
data = combined_dat |>
filter(
disease == target_disease,
date <= max(posterior_predictive_ci$date)
) |>
mutate(data_type = fct_rev(data_type)) |>
arrange(desc(data_type))
) +
scale_color_manual(
name = "Data Type",
values = c("olivedrab1", "deeppink"),
labels = str_to_title
) +
geom_vline(xintercept = last_training_date, linetype = "dashed") +
annotate(
Expand All @@ -121,14 +139,10 @@ make_one_forecast_fig <- function(target_disease,
hjust = "left",
vjust = "bottom",
) +
ggtitle(title, subtitle = glue("as of {last_data_date}")) +
ggtitle(title, subtitle = glue("as of {data_vintage_date}")) +
y_scale +
scale_x_date("Date") +
scale_shape_discrete("Data Type", labels = str_to_title) +
scale_fill_brewer(
name = "Credible Interval Width",
labels = ~ percent(as.numeric(.))
) +
# scale_shape_discrete("Data Type", labels = str_to_title) +
theme(legend.position = "bottom")
}

Expand All @@ -141,8 +155,9 @@ postprocess_state_forecast <- function(model_run_dir,
pluck(1) |>
tail(1)

data_path <- path(model_run_dir, "data", ext = "csv")
inference_data_path <- path(model_run_dir, "inference_data",
train_data_path <- path(model_run_dir, "data", ext = "csv")
eval_data_path <- path(model_run_dir, "eval_data", ext = "tsv")
inference_train_data_path <- path(model_run_dir, "inference_data",
ext = "csv"
)
other_ed_visits_path <- path(
Expand All @@ -151,15 +166,21 @@ postprocess_state_forecast <- function(model_run_dir,
ext = "parquet"
)

dat <- read_csv(
data_path,
col_types = cols(
disease = col_character(),
data_type = col_character(),
ed_visits = col_double(),
date = col_date()
)
) |>
train_dat <- read_csv(train_data_path, show_col_types = FALSE)

data_vintage_date <- max(train_dat$date) + 1
# this should be stored as metadata somewhere else, instead of being
# computed like this

eval_dat <- read_tsv(eval_data_path, show_col_types = FALSE) |>
mutate(data_type = "eval")

combined_dat <-
bind_rows(
train_dat |>
filter(data_type == "train"),
eval_dat
) |>
mutate(
disease = if_else(
disease == disease_name_nssp,
Expand All @@ -180,16 +201,13 @@ postprocess_state_forecast <- function(model_run_dir,
values_to = ".value"
)

last_training_date <- dat |>
filter(data_type == "train") |>
pull(date) |>
max()

last_data_date <- dat |>
last_training_date <- combined_dat |>
filter(data_type == "train") |>
pull(date) |>
max()

pyrenew_samples <- read_pyrenew_samples(inference_data_path,
pyrenew_samples <- read_pyrenew_samples(inference_train_data_path,
filter_bad_chains = filter_bad_chains,
good_chain_tol = good_chain_tol
)
Expand All @@ -198,11 +216,11 @@ postprocess_state_forecast <- function(model_run_dir,
read_parquet(other_ed_visits_path) |>
rename(Other = other_ed_visits)


other_ed_visits_samples <-
bind_rows(
dat |>
combined_dat |>
filter(
data_type == "eval",
disease == "Other",
date <= last_training_date
) |>
Expand All @@ -217,7 +235,7 @@ postprocess_state_forecast <- function(model_run_dir,
pivot_wider(names_from = .variable, values_from = .value) |>
rename(Disease = observed_hospital_admissions) |>
ungroup() |>
mutate(date = min(dat$date) + time) |>
mutate(date = min(combined_dat$date) + time) |>
left_join(other_ed_visits_samples,
by = c(".draw", "date")
) |>
Expand Down Expand Up @@ -250,12 +268,12 @@ postprocess_state_forecast <- function(model_run_dir,


all_forecast_plots <- map(
set_names(unique(dat$disease)),
set_names(unique(combined_dat$disease)),
~ make_one_forecast_fig(
.x,
dat,
combined_dat,
last_training_date,
last_data_date,
data_vintage_date,
posterior_predictive_ci,
state_abb
)
Expand Down Expand Up @@ -285,7 +303,7 @@ p <- arg_parser("Generate forecast figures") |>
) |>
add_argument(
"--filter-bad-chains",
help = "Filter out bad chains from the samples? Default TRUE.",
help = "Filter out bad chains from the samples? Default FALSE",
flag = TRUE
) |>
add_argument(
Expand Down

0 comments on commit 5c04591

Please sign in to comment.