Skip to content

Commit

Permalink
separate functionality in process_state_forecast
Browse files Browse the repository at this point in the history
  • Loading branch information
damonbayer committed Feb 5, 2025
1 parent 955d136 commit dec6999
Showing 1 changed file with 120 additions and 91 deletions.
211 changes: 120 additions & 91 deletions hewr/R/process_state_forecast.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,110 @@
process_timeseries <- function(timeseries_model_dir,
daily_samples,
epiweekly_samples,
daily_training_dat,
epiweekly_training_dat,
required_columns) {
# augment daily and epiweekly other ed visits forecast
# with "sample" format observed data

## ts model, daily denominator
daily_ts_denom_samples <- arrow::read_parquet(
fs::path(timeseries_model_dir,
"daily_baseline_ts_forecast_samples",
ext = "parquet"
)
) |>
dplyr::filter(.data$.variable == "other_ed_visits") |>
to_tidy_draws_timeseries(
observed = daily_training_dat |>
dplyr::filter(.data$.variable == "other_ed_visits") |>
dplyr::select(-"data_type"),
epiweekly = FALSE
) |>
dplyr::select(tidyselect::any_of(required_columns))

Check warning on line 24 in hewr/R/process_state_forecast.R

View check run for this annotation

Codecov / codecov/patch

hewr/R/process_state_forecast.R#L11-L24

Added lines #L11 - L24 were not covered by tests

## ts model, daily denominator aggregated to epiweekly
agg_ewkly_ts_denom_samples <-
daily_ts_denom_samples |>
forecasttools::daily_to_epiweekly(
value_col = ".value",
weekly_value_name = ".value",
id_cols = c(".draw", "geo_value", "disease", ".variable"),
strict = TRUE,
with_epiweek_end_date = TRUE,
epiweek_end_date_name = "date"
) |>
dplyr::select(tidyselect::any_of(required_columns))

Check warning on line 37 in hewr/R/process_state_forecast.R

View check run for this annotation

Codecov / codecov/patch

hewr/R/process_state_forecast.R#L27-L37

Added lines #L27 - L37 were not covered by tests

## ts model, epiweekly denominator
ewkly_ts_denom_samples <- arrow::read_parquet(
fs::path(timeseries_model_dir,
"epiweekly_baseline_ts_forecast_samples",
ext = "parquet"
)
) |>
dplyr::filter(.data$.variable == "other_ed_visits") |>
to_tidy_draws_timeseries(
observed = epiweekly_training_dat |>
dplyr::filter(.data$.variable == "other_ed_visits") |>
dplyr::select(-"data_type"),
epiweekly = TRUE
) |>
dplyr::select(tidyselect::any_of(required_columns))

Check warning on line 53 in hewr/R/process_state_forecast.R

View check run for this annotation

Codecov / codecov/patch

hewr/R/process_state_forecast.R#L40-L53

Added lines #L40 - L53 were not covered by tests

# Daily Numerator, Daily Denominator
daily_samples_daily_n_daily_d <- join_and_calc_prop(
daily_samples,
daily_ts_denom_samples
)

Check warning on line 59 in hewr/R/process_state_forecast.R

View check run for this annotation

Codecov / codecov/patch

hewr/R/process_state_forecast.R#L56-L59

Added lines #L56 - L59 were not covered by tests

# Epiweekly Aggregated Numerator, Epiweekly Aggregated Denominator
ewkly_samples_agg_n_agg_d <- join_and_calc_prop(
epiweekly_samples,
agg_ewkly_ts_denom_samples
)

Check warning on line 65 in hewr/R/process_state_forecast.R

View check run for this annotation

Codecov / codecov/patch

hewr/R/process_state_forecast.R#L62-L65

Added lines #L62 - L65 were not covered by tests

# Epiweekly Aggregated Numerator, Epiweekly Denominator
ewkly_samples_agg_n_ewkly_d <- join_and_calc_prop(
epiweekly_samples,
ewkly_ts_denom_samples
)

Check warning on line 71 in hewr/R/process_state_forecast.R

View check run for this annotation

Codecov / codecov/patch

hewr/R/process_state_forecast.R#L68-L71

Added lines #L68 - L71 were not covered by tests



list(
"daily_samples" = daily_samples_daily_n_daily_d,
"epiweekly_samples" = ewkly_samples_agg_n_agg_d,
"epiweekly_with_epiweekly_other_samples" = ewkly_samples_agg_n_ewkly_d

Check warning on line 78 in hewr/R/process_state_forecast.R

View check run for this annotation

Codecov / codecov/patch

hewr/R/process_state_forecast.R#L75-L78

Added lines #L75 - L78 were not covered by tests
)
}

epiweekly_samples_from_daily <- function(daily_samples, required_columns) {
epiweekly_obs_ed_samples <-
daily_samples |>
dplyr::filter(.data$.variable == "observed_ed_visits") |>
forecasttools::daily_to_epiweekly(
value_col = ".value",
weekly_value_name = ".value",
id_cols = c(
".chain", ".iteration", ".draw", "geo_value", "disease",
".variable"
),
strict = TRUE,
with_epiweek_end_date = TRUE,
epiweek_end_date_name = "date"
) |>
dplyr::select(tidyselect::all_of(required_columns))

Check warning on line 97 in hewr/R/process_state_forecast.R

View check run for this annotation

Codecov / codecov/patch

hewr/R/process_state_forecast.R#L83-L97

Added lines #L83 - L97 were not covered by tests

epiweekly_samples <-
daily_samples |>
dplyr::filter(.data$.variable != "observed_ed_visits") |>
dplyr::bind_rows(epiweekly_obs_ed_samples) |>
dplyr::select(tidyselect::all_of(required_columns))

Check warning on line 103 in hewr/R/process_state_forecast.R

View check run for this annotation

Codecov / codecov/patch

hewr/R/process_state_forecast.R#L99-L103

Added lines #L99 - L103 were not covered by tests

return(epiweekly_samples)

Check warning on line 105 in hewr/R/process_state_forecast.R

View check run for this annotation

Codecov / codecov/patch

hewr/R/process_state_forecast.R#L105

Added line #L105 was not covered by tests
}

#' Combine training and evaluation data for
#' postprocessing.
#'
Expand Down Expand Up @@ -282,30 +389,12 @@ process_state_forecast <- function(model_run_dir,

samples_list <- list(daily_samples = daily_samples)

Check warning on line 390 in hewr/R/process_state_forecast.R

View check run for this annotation

Codecov / codecov/patch

hewr/R/process_state_forecast.R#L390

Added line #L390 was not covered by tests

# For the E model, do epiweekly
# For the E model, do epiweekly and process denominator
if (pyrenew_model_components["e"]) {
epiweekly_obs_ed_samples <-
daily_samples |>
dplyr::filter(.data$.variable == "observed_ed_visits") |>
forecasttools::daily_to_epiweekly(
value_col = ".value",
weekly_value_name = ".value",
id_cols = c(
".chain", ".iteration", ".draw", "geo_value", "disease",
".variable"
),
strict = TRUE,
with_epiweek_end_date = TRUE,
epiweek_end_date_name = "date"
) |>
dplyr::select(tidyselect::all_of(required_columns))

epiweekly_samples <-
daily_samples |>
dplyr::filter(.data$.variable != "observed_ed_visits") |>
dplyr::bind_rows(epiweekly_obs_ed_samples) |>
dplyr::select(tidyselect::all_of(required_columns))

epiweekly_samples <- epiweekly_samples_from_daily(
daily_samples,
required_columns

Check warning on line 396 in hewr/R/process_state_forecast.R

View check run for this annotation

Codecov / codecov/patch

hewr/R/process_state_forecast.R#L393-L396

Added lines #L393 - L396 were not covered by tests
)
samples_list$epiweekly_samples <- epiweekly_samples

Check warning on line 398 in hewr/R/process_state_forecast.R

View check run for this annotation

Codecov / codecov/patch

hewr/R/process_state_forecast.R#L398

Added line #L398 was not covered by tests

## Process timeseries posterior
Expand All @@ -315,79 +404,19 @@ process_state_forecast <- function(model_run_dir,
timeseries_model_name
)

Check warning on line 405 in hewr/R/process_state_forecast.R

View check run for this annotation

Codecov / codecov/patch

hewr/R/process_state_forecast.R#L401-L405

Added lines #L401 - L405 were not covered by tests

# augment daily and epiweekly other ed visits forecast
# with "sample" format observed data

## ts model, daily denominator
daily_ts_denom_samples <- arrow::read_parquet(
fs::path(timeseries_model_dir,
"daily_baseline_ts_forecast_samples",
ext = "parquet"
)
) |>
dplyr::filter(.data$.variable == "other_ed_visits") |>
to_tidy_draws_timeseries(
observed = daily_training_dat |>
dplyr::filter(.data$.variable == "other_ed_visits") |>
dplyr::select(-"data_type"),
epiweekly = FALSE
) |>
dplyr::select(tidyselect::any_of(required_columns))

## ts model, daily denominator aggregated to epiweekly
agg_ewkly_ts_denom_samples <-
daily_ts_denom_samples |>
forecasttools::daily_to_epiweekly(
value_col = ".value",
weekly_value_name = ".value",
id_cols = c(".draw", "geo_value", "disease", ".variable"),
strict = TRUE,
with_epiweek_end_date = TRUE,
epiweek_end_date_name = "date"
) |>
dplyr::select(tidyselect::any_of(required_columns))

## ts model, epiweekly denominator
ewkly_ts_denom_samples <- arrow::read_parquet(
fs::path(timeseries_model_dir,
"epiweekly_baseline_ts_forecast_samples",
ext = "parquet"
)
) |>
dplyr::filter(.data$.variable == "other_ed_visits") |>
to_tidy_draws_timeseries(
observed = epiweekly_training_dat |>
dplyr::filter(.data$.variable == "other_ed_visits") |>
dplyr::select(-"data_type"),
epiweekly = TRUE
) |>
dplyr::select(tidyselect::any_of(required_columns))

# Daily Numerator, Daily Denominator
daily_samples_daily_n_daily_d <- join_and_calc_prop(
timeseries_output <- process_timeseries(
timeseries_model_dir,
daily_samples,
daily_ts_denom_samples
)

samples_list$daily_samples <- daily_samples_daily_n_daily_d

# Epiweekly Aggregated Numerator, Epiweekly Aggregated Denominator
ewkly_samples_agg_n_agg_d <- join_and_calc_prop(
epiweekly_samples,
agg_ewkly_ts_denom_samples
)

samples_list$epiweekly_samples <- ewkly_samples_agg_n_agg_d


# Epiweekly Aggregated Numerator, Epiweekly Denominator
ewkly_samples_agg_n_ewkly_d <- join_and_calc_prop(
epiweekly_samples,
ewkly_ts_denom_samples
daily_training_dat,
epiweekly_training_dat,
required_columns
)

Check warning on line 414 in hewr/R/process_state_forecast.R

View check run for this annotation

Codecov / codecov/patch

hewr/R/process_state_forecast.R#L407-L414

Added lines #L407 - L414 were not covered by tests

samples_list$daily_samples <- timeseries_output$daily_samples
samples_list$epiweekly_samples <- timeseries_output$epiweekly_samples
samples_list$epiweekly_with_epiweekly_other_samples <-
ewkly_samples_agg_n_ewkly_d
timeseries_output$epiweekly_with_epiweekly_other_samples

Check warning on line 419 in hewr/R/process_state_forecast.R

View check run for this annotation

Codecov / codecov/patch

hewr/R/process_state_forecast.R#L416-L419

Added lines #L416 - L419 were not covered by tests
}
}

Expand Down

0 comments on commit dec6999

Please sign in to comment.