Add model_runs directory (3rd attempt) (#152)
* add model_runs as additional hierarchy

* fix default tag

* fix path in postprocess_state_forecast.R

* remove unused base_dir definition

* correct path in timeseries_forecasts

* Issue 137: unify argument patterns (#138)

* add check to all subprocess commands (#143)

* Organize helper functions / utilities (#141)

* provide missing namespace

* merge main into model_runs_2 (#146)

* Issue 137: unify argument patterns (#138)

* add check to all subprocess commands (#143)

* Organize helper functions / utilities (#141)

---------

Co-authored-by: Samuel Brand <[email protected]>
Co-authored-by: Dylan H. Morris <[email protected]>

* load required packages

* more namespace fixes

* use hewr functionality

* first collate plots changes

* use parent dir for saving by default

* put figures in figures dir

* update score tables collation

---------

Co-authored-by: Samuel Brand <[email protected]>
Co-authored-by: Dylan H. Morris <[email protected]>
3 people authored Nov 20, 2024
1 parent cdb04fc commit b90dd9f
Showing 9 changed files with 80 additions and 52 deletions.
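
The core change in this commit is an extra model_runs level between each model batch directory and its per-location run directories. The following Python sketch illustrates the resulting layout; the batch directory name and location code are illustrative, not taken from this commit:

from pathlib import Path

# Illustrative names only; the real batch directory name is generated by the pipeline.
batch_dir = Path("output") / "covid-19_r_2024-11-20"
location = "CA"

# Before this commit, per-location run directories sat directly under the batch directory:
old_run_dir = batch_dir / location

# After this commit, they sit under an intermediate "model_runs" directory:
new_run_dir = batch_dir / "model_runs" / location

print(old_run_dir)  # output/covid-19_r_2024-11-20/CA
print(new_run_dir)  # output/covid-19_r_2024-11-20/model_runs/CA
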
5 changes: 4 additions & 1 deletion hewr/R/directory_utils.R
@@ -79,8 +79,11 @@ parse_model_batch_dir_path <- function(model_batch_dir_path) {
#'
#' @export
parse_model_run_dir_path <- function(model_run_dir_path) {
batch_dir <- fs::path_dir(model_run_dir_path) |>
batch_dir <- model_run_dir_path |>
fs::path_dir() |>
fs::path_dir() |>
fs::path_file()

location <- fs::path_file(model_run_dir_path)

return(c(
2 changes: 1 addition & 1 deletion hewr/tests/testthat/test_directory_utils.R
@@ -26,7 +26,7 @@ invalid_model_batch_dirs <- c(

to_valid_run_dir <- function(valid_batch_dir_entry, location) {
x <- valid_batch_dir_entry
x$dirpath <- fs::path(x$dirname, location)
x$dirpath <- fs::path(x$dirname, "model_runs", location)
x$expected <- c(
location = location,
x$expected
2 changes: 1 addition & 1 deletion pipelines/batch/setup_test_prod_job.py
@@ -18,7 +18,7 @@
description="Test production pipeline on small subset of locations"
)
parser.add_argument(
"tag",
"--tag",
type=str,
help="The tag name to use for the container image version",
default=Path(Repository(os.getcwd()).head.name).stem,
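
With this change, setup_test_prod_job.py takes the image tag as an optional --tag flag rather than a positional argument, falling back to the current git branch name. A minimal standalone sketch of that argument pattern; the pygit2 import is an assumption, since the script's imports are not shown in this diff:

import argparse
import os
from pathlib import Path

from pygit2 import Repository  # assumption: Repository here comes from pygit2

parser = argparse.ArgumentParser(
    description="Test production pipeline on small subset of locations"
)
parser.add_argument(
    "--tag",
    type=str,
    help="The tag name to use for the container image version",
    # e.g. head name "refs/heads/my-branch" -> default tag "my-branch"
    default=Path(Repository(os.getcwd()).head.name).stem,
)
args = parser.parse_args()
print(args.tag)  # the branch name unless --tag is given explicitly
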
55 changes: 44 additions & 11 deletions pipelines/collate_plots.py
@@ -39,6 +39,7 @@ def merge_pdfs_and_save(
def merge_pdfs_from_subdirs(
base_dir: str | Path,
file_name: str,
save_dir: str | Path = None,
output_file_name: str = None,
subdirs_only: list[str] = None,
subdir_pattern="*",
@@ -59,6 +60,11 @@ def merge_pdfs_from_subdirs(
Name of the files to merge. Must be an
exact match.
save_dir
Directory in which to save the merged PDF.
If ``None``, use a "figures" directory in the parent directory of ``base_dir``.
Default ``None``.
output_file_name
Name for the merged PDF file, which will be
saved within ``base_dir``. If ``None``,
@@ -84,6 +90,13 @@ def merge_pdfs_from_subdirs(
-------
None
"""

if save_dir is None:
save_dir = Path(base_dir).parent / "figures"

if not os.path.exists(save_dir):
os.makedirs(save_dir)

subdirs = [
f.name for f in Path(base_dir).glob(subdir_pattern) if f.is_dir()
]
@@ -101,14 +114,15 @@ def merge_pdfs_from_subdirs(
output_file_name = file_name

if len(to_merge) > 0:
merge_pdfs_and_save(to_merge, Path(base_dir, output_file_name))
merge_pdfs_and_save(to_merge, Path(save_dir, output_file_name))

return None


def process_dir(
dir_path: Path | str,
base_dir: Path | str,
target_filenames: str | list[str],
save_dir: Path | str = None,
file_prefix: str = "",
subdirs_only: list[str] = None,
) -> None:
@@ -119,14 +133,17 @@ def process_dir(
Parameters
----------
dir_path
Path to the base directory, in which the merged
PDFs will be saved.
base_dir
Path to the base directory in which to look
target_filenames
One or more PDFs filenames to look for in the
subdirectories and merge.
save_dir
Directory in which to save the merged PDFs.
If ``None``, use a "figures" directory in the parent directory of ``base_dir``. Default ``None``.
file_prefix
Prefix to append to the names in `target_filenames`
when naming the merged files.
@@ -136,17 +153,24 @@ def process_dir(
named subdirectories. If ``None``, look in all
subdirectories of ``base_dir``. Default ``None``.
"""
if save_dir is None:
save_dir = Path(base_dir).parent / "figures"

for file_name in ensure_listlike(target_filenames):
merge_pdfs_from_subdirs(
dir_path,
base_dir,
file_name,
save_dir,
output_file_name=file_prefix + file_name,
subdirs_only=subdirs_only,
)


def collate_from_all_subdirs(
model_base_dir: str | Path, disease: str, target_filenames: str | list[str]
model_base_dir: str | Path,
disease: str,
target_filenames: str | list[str],
save_dir: str | Path = None,
) -> None:
"""
Collate target plots for a given disease
@@ -156,8 +180,7 @@ def collate_from_all_subdirs(
----------
model_base_dir
Path to the base directory in whose subdirectories
the script will look for PDFs to merge and in which
the merged PDFs will be saved.
the script will look for PDFs to merge.
disease
Name of the target disease. Merged PDFs will be named
@@ -167,10 +190,17 @@ def collate_from_all_subdirs(
One or more PDFs filenames to look for in the
subdirectories and merge.
save_dir
Directory in which to save the merged PDFs.
If ``None``, use a "figures" directory in the parent directory of ``model_base_dir``. Default ``None``.
Returns
-------
None
"""
if save_dir is None:
save_dir = Path(model_base_dir).parent / "figures"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@@ -186,8 +216,9 @@ def collate_from_all_subdirs(
for f_dir in forecast_dirs:
logger.info(f"Collating plots from {f_dir}")
process_dir(
dir_path=Path(model_base_dir, f_dir),
base_dir=Path(model_base_dir, f_dir),
target_filenames=target_filenames,
save_dir=save_dir,
)
logger.info("Done collating across locations by date.")

@@ -197,11 +228,13 @@ def collate_from_all_subdirs(
# for multiple diseases.
logger.info("Collating plots from forecast date directories...")
process_dir(
dir_path=model_base_dir,
base_dir=model_base_dir,
target_filenames=target_filenames,
save_dir=save_dir,
file_prefix=f"{disease}_",
subdirs_only=forecast_dirs,
)

logger.info("Done collating plots from forecast date directories.")

return None
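
With the new save_dir argument, merged PDFs no longer land inside base_dir itself; by default they go to a figures directory next to it. A hedged usage sketch follows; the import path, directory names, and per-location file name are hypothetical:

from pathlib import Path

# Assumption: the pipelines package is importable from the repository root.
from pipelines.collate_plots import merge_pdfs_from_subdirs

# Hypothetical layout: one PDF per location under <batch_dir>/model_runs/<location>/.
base_dir = Path("output/covid-19_r_2024-11-20/model_runs")

merge_pdfs_from_subdirs(
    base_dir=base_dir,
    file_name="forecast_plot.pdf",  # hypothetical per-location file name
    save_dir=None,                  # None => "figures" dir next to base_dir
)
# Merged output: output/covid-19_r_2024-11-20/figures/forecast_plot.pdf
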
2 changes: 1 addition & 1 deletion pipelines/collate_score_tables.R
@@ -83,7 +83,7 @@ collate_scores_for_date <- function(model_run_dir,
score_file_ext = "rds",
save = FALSE) {
message(glue::glue("Processing scores from {model_run_dir}..."))
locations_to_process <- fs::dir_ls(model_run_dir,
locations_to_process <- fs::dir_ls(model_run_dir, "model_runs",
type = "directory"
)
date_score_table <- purrr::map(
4 changes: 2 additions & 2 deletions pipelines/forecast_state.py
@@ -194,7 +194,7 @@ def main(

model_batch_dir = Path(output_data_dir, model_batch_dir_name)

model_run_dir = Path(model_batch_dir, state)
model_run_dir = Path(model_batch_dir, "model_runs", state)

os.makedirs(model_run_dir, exist_ok=True)

@@ -212,7 +212,7 @@ def main(
first_training_date=first_training_date,
last_training_date=last_training_date,
param_estimates=param_estimates,
model_batch_dir=model_batch_dir,
model_run_dir=model_run_dir,
logger=logger,
)
logger.info("Data preparation complete.")
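
forecast_state.py now builds the per-state run directory one level deeper and hands process_and_save_state a model_run_dir rather than the batch directory, so the prepared data files land alongside the rest of that state's outputs. A short sketch of the resulting paths; the directory names are illustrative:

import os
from pathlib import Path

# Illustrative names; forecast_state.py constructs the real batch directory name.
model_batch_dir = Path("output") / "influenza_r_2024-11-20"
state = "CA"

# New layout: the per-state run directory sits under "model_runs".
model_run_dir = Path(model_batch_dir, "model_runs", state)
os.makedirs(model_run_dir, exist_ok=True)

# process_and_save_state(..., model_run_dir=model_run_dir, ...) then writes, e.g.:
#   output/influenza_r_2024-11-20/model_runs/CA/data.csv
#   output/influenza_r_2024-11-20/model_runs/CA/data_for_model_fit.json
print(model_run_dir / "data.csv")
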
18 changes: 5 additions & 13 deletions pipelines/postprocess_state_forecast.R
@@ -13,7 +13,8 @@ script_packages <- c(
"tidyr",
"readr",
"here",
"forcats"
"forcats",
"hewr"
)

## load in packages without messages
@@ -255,12 +256,6 @@ postprocess_state_forecast <- function(model_run_dir) {

theme_set(theme_minimal_grid())

disease_name_formatter <- c("covid-19" = "COVID-19", "influenza" = "Flu")
disease_name_nssp_map <- c(
"covid-19" = "COVID-19",
"influenza" = "Influenza"
)

# Create a parser
p <- arg_parser("Generate forecast figures") |>
add_argument(
@@ -271,13 +266,10 @@ p <- arg_parser("Generate forecast figures") |>
argv <- parse_args(p)
model_run_dir <- path(argv$model_run_dir)

base_dir <- path_dir(model_run_dir)

disease_name_raw <- base_dir |>
path_file() |>
str_extract("^.+(?=_r_)")
disease_name_nssp <- parse_model_run_dir_path(model_run_dir)$disease

disease_name_nssp <- unname(disease_name_nssp_map[disease_name_raw])
disease_name_pretty <- unname(disease_name_formatter[disease_name_raw])
disease_name_formatter <- c("COVID-19" = "COVID-19", "Influenza" = "Flu")
disease_name_pretty <- unname(disease_name_formatter[disease_name_nssp])

postprocess_state_forecast(model_run_dir)
14 changes: 8 additions & 6 deletions pipelines/prep_data.py
@@ -219,7 +219,7 @@ def process_and_save_state(
first_training_date: datetime.date,
last_training_date: datetime.date,
param_estimates: pl.LazyFrame,
model_batch_dir: Path,
model_run_dir: Path,
logger: Logger = None,
facility_level_nssp_data: pl.LazyFrame = None,
state_level_nssp_data: pl.LazyFrame = None,
@@ -333,13 +333,15 @@
"right_truncation_offset": right_truncation_offset,
}

state_dir = os.path.join(model_batch_dir, state_abb)
os.makedirs(state_dir, exist_ok=True)
os.makedirs(model_run_dir, exist_ok=True)

logger.info(f"Saving {state_abb} to {state_dir}")
data_to_save.write_csv(Path(state_dir, "data.csv"))
if logger is not None:
logger.info(f"Saving {state_abb} to {model_run_dir}")
data_to_save.write_csv(Path(model_run_dir, "data.csv"))

with open(Path(state_dir, "data_for_model_fit.json"), "w") as json_file:
with open(
Path(model_run_dir, "data_for_model_fit.json"), "w"
) as json_file:
json.dump(data_for_model_fit, json_file)

return None
30 changes: 14 additions & 16 deletions pipelines/timeseries_forecasts.R
@@ -11,7 +11,11 @@ script_packages <- c(
"arrow",
"glue",
"epipredict",
"epiprocess"
"epiprocess",
"purrr",
"rlang",
"glue",
"hewr"
)

## load in packages without messages
@@ -29,12 +33,12 @@ to_prop_forecast <- function(forecast_disease_count,
other_count_col =
"other_ed_visits",
output_col = "prop_disease_ed_visits") {
result <- dplyr::inner_join(
result <- inner_join(
forecast_disease_count,
forecast_other_count,
by = c(".draw", "date")
) |>
dplyr::mutate(
mutate(
!!output_col :=
.data[[disease_count_col]] /
(.data[[disease_count_col]] +
@@ -68,9 +72,9 @@ fit_and_forecast <- function(data,
n_samples = 2000,
target_col = "ed_visits",
output_col = "other_ed_visits") {
forecast_horizon <- glue::glue("{n_forecast_days} days")
target_sym <- rlang::sym(target_col)
output_sym <- rlang::sym(output_col)
forecast_horizon <- glue("{n_forecast_days} days")
target_sym <- sym(target_col)
output_sym <- sym(output_col)

max_visits <- data |>
pull(!!target_sym) |>
@@ -200,21 +204,21 @@ main <- function(model_run_dir, n_forecast_days = 28, n_samples = 2000) {
aheads = 1:n_forecast_days
)

to_save <- tibble::tribble(
to_save <- tribble(
~basename, ~value,
"other_ed_visits_forecast", forecast_other,
"baseline_ts_count_ed_visits_forecast", baseline_ts_count,
"baseline_ts_prop_ed_visits_forecast", baseline_ts_prop,
"baseline_cdc_count_ed_visits_forecast", baseline_cdc_count,
"baseline_cdc_prop_ed_visits_forecast", baseline_cdc_prop
) |>
dplyr::mutate(save_path = path(
mutate(save_path = path(
!!model_run_dir, basename,
ext = "parquet"
))


purrr::walk2(
walk2(
to_save$value,
to_save$save_path,
write_parquet
@@ -250,12 +254,6 @@ disease_name_nssp_map <- c(
"influenza" = "Influenza"
)

base_dir <- path_dir(model_run_dir)

disease_name_raw <- base_dir |>
path_file() |>
str_extract("^.+(?=_r_)")

disease_name_nssp <- unname(disease_name_nssp_map[disease_name_raw])
disease_name_nssp <- parse_model_run_dir_path(model_run_dir)$disease

main(model_run_dir, n_forecast_days, n_samples)
