Skip to content

Commit

Permalink
Modularize Post Processing (#55)
Browse files Browse the repository at this point in the history
* modular scripts

* make all_post_process executable

* working on post_process

* waiting for renv update to test

* option to filter bad chains and workaround for command line

* parse args

* more command line args
  • Loading branch information
damonbayer authored Oct 21, 2024
1 parent 42c3891 commit cc7f156
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 51 deletions.
28 changes: 28 additions & 0 deletions nssp_demo/all_post_process.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#!/bin/bash

# Check if the base directory is provided as an argument
if [ -z "$1" ]; then
echo "Usage: $0 <base_dir>"
exit 1
fi

# Base directory containing subdirectories
BASE_DIR="$1"


# Iterate over each subdirectory in the base directory
for SUBDIR in "$BASE_DIR"/*/; do
# Run the R script with the current subdirectory as the model_dir argument
echo "$SUBDIR"
# will work once https://github.com/rstudio/renv/pull/2018 is merged
Rscript -e "renv::run(\"post_process.R\", project = \"..\", args = c(\"--model_dir ${SUBDIR}\"))"
done


# # Get the name of the current directory (base_dir)
base_dir_name=$(basename "$(pwd)")

# Find all forecast_plot.pdf files and combine them using pdfunite
find . -name "forecast_plot.pdf" | sort | xargs pdfunite - "${BASE_DIR}/${base_dir_name}_all_forecasts.pdf"

echo "Combined PDF created: ${base_dir_name}_all_forecasts.pdf"
153 changes: 102 additions & 51 deletions nssp_demo/post_process.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,86 @@ library(cowplot)
library(glue)
library(scales)
library(here)
library(argparser)

# Create a parser
p <- arg_parser("Generate forecast figures") %>%
add_argument(p, "--model_dir",
help = "Directory containing the model data",
required = TRUE
) %>%
add_argument(p, "--filter_bad_chains",
help = "Filter out bad chains from the samples",
flag = TRUE
) %>%
add_argument(p, "--good_chain_tol",
help = "Tolerance level for determining good chains",
default = 2
)

argv <- parse_args(p)
model_dir <- path(argv$model_dir)
filter_bad_chains <- argv$filter_bad_chains
good_chain_tol <- argv$good_chain_tol

base_dir <- path_dir(model_dir)

theme_set(theme_minimal_grid())

disease_name_formatter <- c("covid-19" = "COVID-19", "influenza" = "Flu")

make_forecast_fig <- function(model_dir) {
read_pyrenew_samples <- function(inference_data_path,
filter_bad_chains = TRUE,
good_chain_tol = 2) {
arviz_split <- function(x) {
x %>%
select(-distribution) %>%
split(f = as.factor(x$distribution))
}

pyrenew_samples <-
read_csv(inference_data_path) %>%
rename_with(\(varname) str_remove_all(varname, "\\(|\\)|\\'|(, \\d+)")) |>
rename(
.chain = chain,
.iteration = draw
) |>
mutate(across(c(.chain, .iteration), \(x) as.integer(x + 1))) |>
mutate(
.draw = tidybayes:::draw_from_chain_and_iteration_(.chain, .iteration),
.after = .iteration
) |>
pivot_longer(-starts_with("."),
names_sep = ", ",
names_to = c("distribution", "name")
) |>
arviz_split() |>
map(\(x) pivot_wider(x, names_from = name) |> tidy_draws())

if (filter_bad_chains) {
good_chains <-
pyrenew_samples$log_likelihood %>%
pivot_longer(-starts_with(".")) %>%
group_by(.iteration, .chain) %>%
summarize(value = sum(value)) %>%
group_by(.chain) %>%
summarize(value = mean(value)) %>%
filter(value >= max(value) - 2) %>%
pull(.chain)
} else {
good_chains <- unique(pyrenew_samples$log_likelihood$.chain)
}

good_pyrenew_samples <- map(
pyrenew_samples,
\(x) filter(x, .chain %in% good_chains)
)
good_pyrenew_samples
}

make_forecast_fig <- function(model_dir,
filter_bad_chains = TRUE,
good_chain_tol = 2) {
disease_name_raw <- base_dir %>%
path_file() %>%
str_extract("^.+(?=_r_)")
Expand All @@ -20,17 +94,15 @@ make_forecast_fig <- function(model_dir) {
pluck(1) %>%
tail(1)


data_path <- path(model_dir, "data", ext = "csv")
inference_data_path <- path(model_dir, "inference_data",
ext = "csv"
)


dat <- read_csv(data_path) %>%
arrange(date) %>%
mutate(time = row_number() - 1) %>%
rename(.value = COVID_ED_admissions)
rename(.value = ED_admissions)

last_training_date <- dat %>%
filter(data_type == "train") %>%
Expand All @@ -41,31 +113,11 @@ make_forecast_fig <- function(model_dir) {
pull(date) %>%
max()

arviz_split <- function(x) {
x %>%
select(-distribution) %>%
split(f = as.factor(x$distribution))
}

pyrenew_samples <-
read_csv(inference_data_path) %>%
rename_with(\(varname) str_remove_all(varname, "\\(|\\)|\\'|(, \\d+)")) |>
rename(
.chain = chain,
.iteration = draw
) |>
mutate(across(c(.chain, .iteration), \(x) as.integer(x + 1))) |>
mutate(
.draw = tidybayes:::draw_from_chain_and_iteration_(.chain, .iteration),
.after = .iteration
) |>
pivot_longer(-starts_with("."),
names_sep = ", ",
names_to = c("distribution", "name")
) |>
arviz_split() |>
map(\(x) pivot_wider(x, names_from = name) |> tidy_draws())

pyrenew_samples <- read_pyrenew_samples(inference_data_path,
filter_bad_chains = filter_bad_chains,
good_chain_tol = good_chain_tol
)

hosp_ci <-
pyrenew_samples$posterior_predictive %>%
Expand Down Expand Up @@ -118,35 +170,34 @@ make_forecast_fig <- function(model_dir) {
forecast_plot
}

forecast_fig <- make_forecast_fig(model_dir, filter_bad_chains, good_chain_tol)

base_dir <- path(here(
save_plot(
filename = path(model_dir, "forecast_plot", ext = "pdf"),
plot = forecast_fig,
device = cairo_pdf, base_height = 6
)


# Temp code while command line version doesn't work
base_dir <- path(
"nssp_demo",
"private_data",
"covid-19_r_2024-10-10_f_2024-04-12_l_2024-10-09_t_2024-10-05"
))
"influenza_r_2024-10-10_f_2024-04-12_l_2024-10-09_t_2024-10-05"
)

walk(dir_ls(base_dir), function(model_dir) {
forecast_fig <- make_forecast_fig(model_dir)

forecast_fig_tbl <-
tibble(base_model_dir = dir_ls(base_dir)) %>%
filter(
path(base_model_dir, "inference_data", ext = "csv") %>%
file_exists()
) %>%
mutate(forecast_fig = map(base_model_dir, make_forecast_fig)) %>%
mutate(figure_path = path(base_model_dir, "forecast_plot", ext = "pdf"))

pwalk(
forecast_fig_tbl %>% select(forecast_fig, figure_path),
function(forecast_fig, figure_path) {
save_plot(
filename = figure_path,
plot = forecast_fig,
device = cairo_pdf, base_height = 6
)
}
)
save_plot(
filename = path(model_dir, "forecast_plot", ext = "pdf"),
plot = forecast_fig,
device = cairo_pdf, base_height = 6
)
})

str_c(forecast_fig_tbl$figure_path, collapse = " ") %>%
path(dir_ls(base_dir, type = "directory"), "forecast_plot", ext = "pdf") %>%
str_c(collapse = " ") %>%
str_c(
path(base_dir,
glue("{path_file(base_dir)}_all_forecasts"),
Expand Down

0 comments on commit cc7f156

Please sign in to comment.