Skip to content

Commit

Permalink
Post-processing for H models (#287)
Browse files Browse the repository at this point in the history
  • Loading branch information
damonbayer authored Feb 7, 2025
1 parent 3a0a5bf commit 216a81b
Show file tree
Hide file tree
Showing 37 changed files with 1,246 additions and 1,750 deletions.
1 change: 1 addition & 0 deletions hewr/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Imports:
tibble,
tidybayes,
tidyr,
tidyselect,
urca
Remotes:
https://github.com/cdcgov/forecasttools
Expand Down
8 changes: 5 additions & 3 deletions hewr/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
export(combine_training_and_eval_data)
export(generate_exp_growth_pois)
export(get_all_model_batch_dirs)
export(group_time_index_to_date)
export(make_forecast_figure)
export(parse_model_batch_dir_path)
export(parse_model_run_dir_path)
export(pivot_ed_visit_df_longer)
export(parse_pyrenew_model_name)
export(parse_variable_name)
export(process_state_forecast)
export(read_and_combine_data)
export(score_hubverse)
export(to_epiweekly_quantile_table)
export(to_epiweekly_quantiles)
export(with_prop_disease_ed_visits)
importFrom(rlang,":=")
importFrom(rlang,.data)
60 changes: 60 additions & 0 deletions hewr/R/directory_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,63 @@ get_all_model_batch_dirs <- function(dir_of_batch_dirs,

return(dirs)
}

#' Parse PyRenew Model Name
#'
#' Takes everything after the first underscore in the model name, splits it
#' into single characters, and reports for each of the letters `"h"`, `"e"`,
#' and `"w"` whether it occurs among those characters.
#'
#' @param pyrenew_model_name name of a pyrenew model, as a single string
#' ("pyrenew_h", "pyrenew_he", "pyrenew_hew", etc)
#'
#' @returns a logical vector named `h`, `e`, `w` indicating which components
#' are present
#' @export
#'
#' @examples parse_pyrenew_model_name("pyrenew_h")
parse_pyrenew_model_name <- function(pyrenew_model_name) {
  component_letters <- c("h", "e", "w")

  # Everything following the first underscore, e.g. "hew" for "pyrenew_hew".
  name_suffix <- stringr::str_extract(pyrenew_model_name, "(?<=_).+$")
  suffix_letters <- stringr::str_split_1(name_suffix, "")

  is_present <- component_letters %in% suffix_letters
  purrr::set_names(is_present, component_letters)
}


#' Parse variable name.
#'
#' Derive display metadata (labels and an axis formatter) from a variable
#' name, for use in plots.
#'
#' @param variable_name Character. Name of the variable to parse. Must be a
#' single string, since it is used in a scalar `if` condition.
#' @return A list containing:
#' - `proportion`: Logical. `TRUE` when `variable_name` starts with `"prop"`.
#' - `core_name`: Character. `"Emergency Department Visits"` if the name
#'   contains `"ed_visits"`, otherwise `"Hospital Admissions"` if it contains
#'   `"hospital"`, otherwise `""`.
#' - `full_name`: Character. `core_name`, prefixed with `"Proportion of "`
#'   when `proportion` is `TRUE`.
#' - `y_axis_labels`: Function. The result of [scales::label_percent()] for
#'   proportions, and of [scales::label_comma()] otherwise.
#' @export
#'
#' @examples
#' parse_variable_name("prop_hospital_admissions")
parse_variable_name <- function(variable_name) {
  is_prop <- stringr::str_starts(variable_name, "prop")

  mentions_ed_visits <- stringr::str_detect(variable_name, "ed_visits")
  mentions_hospital <- stringr::str_detect(variable_name, "hospital")

  # This local must stay named `core_name`: the glue template below
  # interpolates it by name.
  core_name <- dplyr::case_when(
    mentions_ed_visits ~ "Emergency Department Visits",
    mentions_hospital ~ "Hospital Admissions",
    TRUE ~ ""
  )
  prefixed_name <- glue::glue("Proportion of {core_name}")

  # Pick the label factory first, then call it once when building the result.
  label_factory <- if (is_prop) {
    scales::label_percent
  } else {
    scales::label_comma
  }

  list(
    proportion = is_prop,
    core_name = core_name,
    full_name = dplyr::if_else(is_prop, prefixed_name, core_name),
    y_axis_labels = label_factory()
  )
}
2 changes: 1 addition & 1 deletion hewr/R/hewr-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
"_PACKAGE"

## usethis namespace: start
#' @importFrom rlang .data
#' @importFrom rlang .data :=
## usethis namespace: end
NULL
59 changes: 27 additions & 32 deletions hewr/R/make_forecast_figure.R
Original file line number Diff line number Diff line change
@@ -1,59 +1,51 @@
#' Make Forecast Figure
#'
#' @param target_disease a disease matching the disease columns
#' @param target_variable a variable matching the .variable columns
#' in `combined_dat` and `forecast_ci`
#' @param combined_dat `combined_dat` from the result of
#' [process_state_forecast()]
#' @param forecast_ci `forecast_ci` from the result of
#' [process_state_forecast()]
#' @param disease_name `"COVID-19"` or `"Influenza"`
#' @param data_vintage_date date that the data was collected
#' @param y_transform a character passed as the transform argument to
#' [ggplot2::scale_y_continuous()].
#'
#' @return a ggplot object
#' @export
make_forecast_figure <- function(target_disease,
make_forecast_figure <- function(target_variable,
combined_dat,
forecast_ci,
disease_name = c("COVID-19", "Influenza"),
data_vintage_date,
y_transform = "identity") {
disease_name <- rlang::arg_match(disease_name)
disease_name <- forecast_ci[["disease"]][1]
disease_name_pretty <- c(
"COVID-19" = "COVID-19",
"Influenza" = "Flu"
)[disease_name]
state_abb <- unique(combined_dat$geo_value)[1]

y_scale <- if (stringr::str_starts(target_disease, "prop")) {
ggplot2::scale_y_continuous("Proportion of Emergency Department Visits",
labels = scales::label_percent(),
transform = y_transform
)
} else {
ggplot2::scale_y_continuous("Emergency Department Visits",
labels = scales::label_comma(),
transform = y_transform
)
}
state_abb <- unique(combined_dat$geo_value)[1]
parsed_variable_name <- parse_variable_name(target_variable)

y_axis_label <- parsed_variable_name[["full_name"]]
y_axis_labels <- parsed_variable_name[["y_axis_labels"]]
core_name <- parsed_variable_name[["core_name"]]

title <- if (target_disease == "Other") {
glue::glue("Other ED Visits in {state_abb}")
} else {
glue::glue("{disease_name_pretty} ED Visits in {state_abb}")
}
title_prefix <- ifelse(
stringr::str_starts(target_variable, "other"),
"Other",
disease_name_pretty
)
title <- glue::glue("{title_prefix} {core_name} in {state_abb}")

last_training_date <- combined_dat |>
dplyr::filter(data_type == "train") |>
dplyr::filter(.data$data_type == "train") |>
dplyr::pull(date) |>
max()

ggplot2::ggplot(mapping = ggplot2::aes(date, .value)) +
ggplot2::ggplot(mapping = ggplot2::aes(.data$date, .data$.value)) +
ggdist::geom_lineribbon(
data = forecast_ci |> dplyr::filter(disease == target_disease),
mapping = ggplot2::aes(ymin = .lower, ymax = .upper),
data = forecast_ci |> dplyr::filter(.data$.variable == target_variable),
mapping = ggplot2::aes(ymin = .data$.lower, ymax = .data$.upper),
color = "#08519c",
key_glyph = ggplot2::draw_key_rect,
step = "mid"
Expand All @@ -63,14 +55,14 @@ make_forecast_figure <- function(target_disease,
labels = ~ scales::label_percent()(as.numeric(.))
) +
ggplot2::geom_point(
mapping = ggplot2::aes(color = data_type), size = 1.5,
mapping = ggplot2::aes(color = .data$data_type), size = 1.5,
data = combined_dat |>
dplyr::filter(
disease == target_disease,
date <= max(forecast_ci$date)
.data$.variable == target_variable,
.data$date <= max(forecast_ci$date)
) |>
dplyr::mutate(data_type = forcats::fct_rev(data_type)) |>
dplyr::arrange(dplyr::desc(data_type))
dplyr::mutate(data_type = forcats::fct_rev(.data$data_type)) |>
dplyr::arrange(dplyr::desc(.data$data_type))
) +
ggplot2::scale_color_manual(
name = "Data Type",
Expand All @@ -96,7 +88,10 @@ make_forecast_figure <- function(target_disease,
ggplot2::ggtitle(title,
subtitle = glue::glue("as of {data_vintage_date}")
) +
y_scale +
ggplot2::scale_y_continuous(y_axis_label,
labels = y_axis_labels,
transform = y_transform
) +
ggplot2::scale_x_date("Date") +
cowplot::theme_minimal_grid() +
ggplot2::theme(legend.position = "bottom")
Expand Down
Loading

0 comments on commit 216a81b

Please sign in to comment.