diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d4ab5a48..7bf26f02 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,7 +26,7 @@ repos: ##### # R - repo: https://github.com/lorenzwalthert/precommit - rev: v0.4.3.9001 + rev: v0.4.3 hooks: - id: style-files - id: lintr diff --git a/hewr/R/to_epiweekly_quantile_table.R b/hewr/R/to_epiweekly_quantile_table.R index a08353ea..b3285c94 100644 --- a/hewr/R/to_epiweekly_quantile_table.R +++ b/hewr/R/to_epiweekly_quantile_table.R @@ -65,7 +65,8 @@ to_epiweekly_quantiles <- function(model_run_dir, value_col = ".value" ) |> dplyr::mutate( - "location" = !!location + "location" = !!location, + "source_samples" = !!draws_file_name ) message(glue::glue("Done processing {model_run_dir}")) return(epiweekly_quantiles) @@ -98,7 +99,7 @@ to_epiweekly_quantile_table <- function(model_batch_dir, epiweekly_other_locations = c()) { model_runs_path <- fs::path(model_batch_dir, "model_runs") - locations_to_process <- fs::dir_ls(model_runs_path, + model_run_dirs_to_process <- fs::dir_ls(model_runs_path, type = "directory" ) |> purrr::discard(~ fs::path_file(.x) %in% exclude) @@ -122,15 +123,25 @@ to_epiweekly_quantile_table <- function(model_batch_dir, day_of_week = 7 ) - get_location_table <- \(loc) { - epiweekly_other <- loc %in% epiweekly_other_locations + get_location_table <- \(model_run_dir) { + loc <- fs::path_file(model_run_dir) + use_epiweekly_other <- loc %in% epiweekly_other_locations + which_forecast <- ifelse(use_epiweekly_other, + "explicitly epiweekly", + "aggregated daily" + ) + glue::glue( + "Using {which_forecast} non-target ED visit forecast ", + "for location {loc}" + ) + draws_file <- ifelse( - epiweekly_other, + use_epiweekly_other, "epiweekly_with_epiweekly_other_samples", "epiweekly_samples" ) return(to_epiweekly_quantiles( - loc, + model_run_dir, report_date = report_date, max_lookback_days = 15, draws_file_name = draws_file, @@ -138,16 +149,24 @@ to_epiweekly_quantile_table <- function(model_batch_dir, )) } - hubverse_table <- purrr::map( - locations_to_process, + quant_table <- purrr::map( + model_run_dirs_to_process, get_location_table ) |> - dplyr::bind_rows() |> + dplyr::bind_rows() + + loc_sources <- quant_table |> + dplyr::distinct(.data$location, .data$source_samples) + + hubverse_table <- quant_table |> forecasttools::get_hubverse_table( report_epiweek_end, target_name = glue::glue("wk inc {disease_abbr} prop ed visits") ) |> + dplyr::inner_join(loc_sources, + by = "location" + ) |> dplyr::arrange( .data$target, .data$output_type, @@ -155,12 +174,6 @@ to_epiweekly_quantile_table <- function(model_batch_dir, .data$reference_date, .data$horizon, .data$output_type_id - ) |> - dplyr::mutate(other_ed_visit_forecast = ifelse( - .data$location %in% !!epiweekly_other_locations, - "direct_epiweekly_fit", - "aggregated_daily_fit" - )) - + ) return(hubverse_table) } diff --git a/hewr/tests/testthat/test_to_epiweekly_quantile_table.R b/hewr/tests/testthat/test_to_epiweekly_quantile_table.R index 15a6942b..b4600223 100644 --- a/hewr/tests/testthat/test_to_epiweekly_quantile_table.R +++ b/hewr/tests/testthat/test_to_epiweekly_quantile_table.R @@ -39,9 +39,20 @@ test_that("to_epiweekly_quantiles works as expected", { ) |> suppressMessages() expect_s3_class(result, "tbl_df") - expect_setequal(c( - "epiweek", "epiyear", "quantile_value", "quantile_level", "location" - ), colnames(result)) + checkmate::expect_names( + colnames(result), + identical.to = c( + "epiweek", + "epiyear", + "quantile_value", + "quantile_level", + "location", + "source_samples" + ) + ) + + expect_equal(draws_file_name, unique(result$source_samples)) + expect_gt(nrow(result), 0) } @@ -127,7 +138,11 @@ test_that("to_epiweekly_quantiles handles missing forecast files", { # tests for `to_epiweekly_quantile_table` -test_that("to_epiweekly_quantile_table handles multiple locations", { +test_that(paste0( + "to_epiweekly_quantile_table ", + "handles multiple locations ", + "and multiple source files" +), { batch_dir_name <- "covid-19_r_2024-12-14_f_2024-12-08_t_2024-12-14" tempdir <- withr::local_tempdir() @@ -142,6 +157,17 @@ test_that("to_epiweekly_quantile_table handles multiple locations", { if (loc != "loc3") { disease_cols <- c(disease_cols, "prop_disease_ed_visits") } + create_tidy_forecast_data( + directory = loc_dir, + filename = "epiweekly_with_epiweekly_other_samples.parquet", + date_cols = seq( + lubridate::ymd("2024-12-08"), lubridate::ymd("2024-12-14"), + by = "week" + ), + disease_cols = disease_cols, + n_draw = 25, + with_epiweek = TRUE + ) create_tidy_forecast_data( directory = loc_dir, @@ -157,7 +183,10 @@ test_that("to_epiweekly_quantile_table handles multiple locations", { }) ## should succeed despite loc3 not having valid draws with strict = FALSE - result_w_both_locations <- to_epiweekly_quantile_table(temp_batch_dir) |> + result_w_both_locations <- + to_epiweekly_quantile_table(temp_batch_dir, + epiweekly_other_locations = "loc1" + ) |> suppressMessages() ## should error if strict = TRUE because loc3 does not have @@ -168,6 +197,44 @@ test_that("to_epiweekly_quantile_table handles multiple locations", { "did not find valid draws" ) + ## should succeed with strict = TRUE if loc3 is excluded + alt_result_w_both_locations <- ( + to_epiweekly_quantile_table(temp_batch_dir, + strict = TRUE, + exclude = "loc3" + )) |> + suppressMessages() + + ## results should be equivalent for loc2, + ## but not for loc1 + expect_equal( + result_w_both_locations |> + dplyr::filter(location == "loc2"), + alt_result_w_both_locations |> + dplyr::filter(location == "loc2") + ) + + ## check that one used epiweekly + ## other for loc1 while other used + ## default, resulting in different values + loc1_a <- result_w_both_locations |> + dplyr::filter(location == "loc1") |> + dplyr::pull(.data$value) + loc1_b <- alt_result_w_both_locations |> + dplyr::filter(location == "loc1") |> + dplyr::pull(.data$value) + + ## length checks ensure that the + ## number of allowed equalities _could_ + ## be reached if the vectors were mostly + ## or entirely identical + expect_gt(length(loc1_a), 10) + expect_gt(length(loc1_b), 10) + expect_lt( + sum(loc1_a == loc1_b), + 5 + ) + expect_s3_class(result_w_both_locations, "tbl_df") expect_gt(nrow(result_w_both_locations), 0) checkmate::expect_names( @@ -181,20 +248,32 @@ test_that("to_epiweekly_quantile_table handles multiple locations", { "output_type", "output_type_id", "value", - "other_ed_visit_forecast" + "source_samples" ) ) expect_setequal( - c("loc1", "loc2"), - result_w_both_locations$location + result_w_both_locations$location, + c("loc1", "loc2") + ) + expect_setequal( + alt_result_w_both_locations$location, + c("loc1", "loc2") ) - expect_false("loc3" %in% result_w_both_locations$location) - result_w_one_location <- to_epiweekly_quantile_table( - model_batch_dir = temp_batch_dir, - exclude = "loc1" - ) |> - suppressMessages() - expect_true("loc2" %in% result_w_one_location$location) - expect_false("loc1" %in% result_w_one_location$location) + expect_setequal( + result_w_both_locations$source_samples, + c( + "epiweekly_samples", + "epiweekly_with_epiweekly_other_samples" + ) + ) + + expect_setequal( + alt_result_w_both_locations$source_samples, + "epiweekly_samples" + ) + + + expect_false("loc3" %in% result_w_both_locations$location) + expect_false("loc3" %in% alt_result_w_both_locations$location) }) diff --git a/pipelines/postprocess_forecast_batches.py b/pipelines/postprocess_forecast_batches.py index f46fa6a2..a5690ce7 100644 --- a/pipelines/postprocess_forecast_batches.py +++ b/pipelines/postprocess_forecast_batches.py @@ -15,6 +15,7 @@ from pathlib import Path import collate_plots as cp +from forecasttools.utils import ensure_listlike from pipelines.hubverse_create_observed_data_tables import ( save_observed_data_tables, @@ -28,7 +29,16 @@ def _hubverse_table_filename( return f"{report_date}-" f"{disease.lower()}-" "hubverse-table.tsv" -def create_hubverse_table(model_batch_dir_path: str | Path) -> None: +def create_hubverse_table( + model_batch_dir_path: str | Path, + locations_exclude: str | list[str] = "", + epiweekly_other_locations: str | list[str] = "", +) -> None: + logger = logging.getLogger(__name__) + + locations_exclude = ensure_listlike(locations_exclude) + epiweekly_other_locations = ensure_listlike(epiweekly_other_locations) + model_batch_dir_path = Path(model_batch_dir_path) model_batch_dir_name = model_batch_dir_path.name batch_info = parse_model_batch_dir_name(model_batch_dir_name) @@ -45,13 +55,18 @@ def create_hubverse_table(model_batch_dir_path: str | Path) -> None: "pipelines/hubverse_create_table.R", f"{model_batch_dir_path}", f"{output_path}", + "--exclude", + f"{' '.join(locations_exclude)}", + "--epiweekly-other-locations", + f"{' '.join(epiweekly_other_locations)}", ], capture_output=True, ) if result.returncode != 0: raise RuntimeError( - "create_hubverse_table: " f"{result.stdout}\n" f"{result.stderr}" + f"create_hubverse_table: {result.stdout}\n {result.stderr}" ) + return None @@ -86,10 +101,16 @@ def create_pointinterval_plot(model_batch_dir_path: Path | str) -> None: def process_model_batch_dir( - model_batch_dir_path: Path, plot_ext: str = "pdf" + model_batch_dir_path: Path, + locations_exclude: str | list[str] = "", + epiweekly_other_locations: str | list[str] = "", + plot_ext: str = "pdf", ) -> None: + locations_exclude = ensure_listlike(locations_exclude) + epiweekly_other_locations = ensure_listlike(epiweekly_other_locations) + plot_types = ["Disease", "Other", "prop_disease_ed_visits"] - plot_timescales = ["daily", "epiweekly", "epiweekly_other"] + plot_timescales = ["daily", "epiweekly", "epiweekly_with_epiweekly_other"] plot_yscales = ["", "log_"] plots_to_collate = [ @@ -97,14 +118,24 @@ def process_model_batch_dir( for p_type, p_yscale, p_timescale in itertools.product( plot_types, plot_yscales, plot_timescales ) - if not (p_type == "Disease" and p_timescale == "epiweekly_other") + if not ( + p_type == "Disease" + and p_timescale == "epiweekly_with_epiweekly_other" + ) ] logger = logging.getLogger(__name__) logger.info("Collating plots...") cp.process_dir(model_batch_dir_path, target_filenames=plots_to_collate) logger.info("Creating hubverse table...") - create_hubverse_table(model_batch_dir_path) + logger.info( + "Using epiweekly other forecast for " f"{epiweekly_other_locations}..." + ) + create_hubverse_table( + model_batch_dir_path, + locations_exclude=locations_exclude, + epiweekly_other_locations=epiweekly_other_locations, + ) logger.info("Creating pointinterval plot...") create_pointinterval_plot(model_batch_dir_path) @@ -112,15 +143,27 @@ def process_model_batch_dir( def main( base_forecast_dir: Path | str, path_to_latest_data: Path | str, - diseases: list[str] = ["COVID-19", "Influenza"], + diseases: list[str] = None, + locations_exclude: str | list[str] = "", + epiweekly_other_locations: str | list[str] = "", ) -> None: + if diseases is None: + diseases = ["COVID-19", "Influenza"] + diseases = ensure_listlike(diseases) + locations_exclude = ensure_listlike(locations_exclude) + epiweekly_other_locations = ensure_listlike(epiweekly_other_locations) + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) to_process = get_all_forecast_dirs(base_forecast_dir, diseases) for batch_dir in to_process: logger.info(f"Processing {batch_dir}...") model_batch_dir_path = Path(base_forecast_dir, batch_dir) - process_model_batch_dir(model_batch_dir_path) + process_model_batch_dir( + model_batch_dir_path, + locations_exclude=locations_exclude, + epiweekly_other_locations=epiweekly_other_locations, + ) logger.info(f"Finished processing {batch_dir}") logger.info("Created observed data tables for visualization...") save_observed_data_tables( @@ -157,7 +200,30 @@ def main( "Default 'COVID-19 Influenza' (i.e. postprocess both)." ), ) + parser.add_argument( + "--locations-exclude", + type=str, + default="", + help=( + "Name(s) of locations to exclude from the hubverse table, " + "as a whitespace-separated string of two-letter location " + "codes." + ), + ) + parser.add_argument( + "--epiweekly-other-locations", + type=str, + default="", + help=( + "Name(s) of locations for which to use an explicitly epiweekly " + "forecast of other (non-target) ED visits, as opposed to a " + "daily forecast aggregated to epiweekly. Locations should be " + "specified as a whitespace-separated string of two-letter codes." + ), + ) args = parser.parse_args() args.diseases = args.diseases.split() + args.locations_exclude = args.locations_exclude.split() + args.epiweekly_other_locations = args.epiweekly_other_locations.split() main(**vars(args))