Skip to content

Commit

Permalink
Denominator Forecasting (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
damonbayer authored Oct 25, 2024
1 parent 7253f9b commit ce71163
Show file tree
Hide file tree
Showing 9 changed files with 505 additions and 112 deletions.
8 changes: 4 additions & 4 deletions notebooks/hosp_only_ww_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ We begin by loading the Stan data, converting it the correct inputs for our mode

```{python}
# | label: create model
my_hosp_only_ww_model, data_observed_hospital_admissions = (
my_hosp_only_ww_model, data_observed_disease_hospital_admissions = (
create_hosp_only_ww_model_from_stan_data(
"data/fit_hosp_only/stan_data.json"
)
Expand All @@ -50,7 +50,7 @@ We check that we can simulate from the prior predictive
n_forecast_days = 35
prior_predictive = my_hosp_only_ww_model.prior_predictive(
n_datapoints=len(data_observed_hospital_admissions) + n_forecast_days,
n_datapoints=len(data_observed_disease_hospital_admissions) + n_forecast_days,
numpyro_predictive_args={"num_samples": 200},
)
```
Expand All @@ -64,7 +64,7 @@ my_hosp_only_ww_model.run(
num_warmup=500,
num_samples=500,
rng_key=jax.random.key(200),
data_observed_hospital_admissions=data_observed_hospital_admissions,
data_observed_disease_hospital_admissions=data_observed_disease_hospital_admissions,
mcmc_args=dict(num_chains=4, progress_bar=False),
nuts_args=dict(find_heuristic_step_size=True),
)
Expand All @@ -75,7 +75,7 @@ Create the posterior predictive and forecast:
```{python}
# | label: posterior predictive
posterior_predictive = my_hosp_only_ww_model.posterior_predictive(
n_datapoints=len(data_observed_hospital_admissions) + n_forecast_days
n_datapoints=len(data_observed_disease_hospital_admissions) + n_forecast_days
)
```

Expand Down
10 changes: 7 additions & 3 deletions nssp_demo/build_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def build_model_from_dir(model_dir):
jnp.array(model_data["generation_interval_pmf"]),
) # check if off by 1 or reversed

data_observed_hospital_admissions = jnp.array(
model_data["data_observed_hospital_admissions"]
data_observed_disease_hospital_admissions = jnp.array(
model_data["data_observed_disease_hospital_admissions"]
)
state_pop = jnp.array(model_data["state_pop"])

Expand Down Expand Up @@ -84,4 +84,8 @@ def build_model_from_dir(model_dir):
n_initialization_points=uot,
)

return my_model, data_observed_hospital_admissions, right_truncation_offset
return (
my_model,
data_observed_disease_hospital_admissions,
right_truncation_offset,
)
10 changes: 6 additions & 4 deletions nssp_demo/fit_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@
args = parser.parse_args()
model_dir = args.model_dir

my_model, data_observed_hospital_admissions, right_truncation_offset = (
build_model_from_dir(model_dir)
)
(
my_model,
data_observed_disease_hospital_admissions,
right_truncation_offset,
) = build_model_from_dir(model_dir)
my_model.run(
num_warmup=500,
num_samples=500,
rng_key=jax.random.key(200),
data_observed_hospital_admissions=data_observed_hospital_admissions,
data_observed_disease_hospital_admissions=data_observed_disease_hospital_admissions,
right_truncation_offset=right_truncation_offset,
mcmc_args=dict(num_chains=n_chains, progress_bar=True),
nuts_args=dict(find_heuristic_step_size=True),
Expand Down
13 changes: 8 additions & 5 deletions nssp_demo/generate_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
args = parser.parse_args()
model_dir = args.model_dir
n_forecast_points = args.n_forecast_points
my_model, data_observed_hospital_admissions, right_truncation_offset = (
build_model_from_dir(model_dir)
)
(
my_model,
data_observed_disease_hospital_admissions,
right_truncation_offset,
) = build_model_from_dir(model_dir)

my_model._init_model(1, 1)
fresh_sampler = my_model.mcmc.sampler
Expand All @@ -47,12 +49,13 @@
# "num_samples": my_model.mcmc.num_samples * my_model.mcmc.num_chains,
# "batch_ndims":1
# },
# n_datapoints=len(data_observed_hospital_admissions) + n_forecast_points,
# n_datapoints=len(data_observed_disease_hospital_admissions) + n_forecast_points,
# )
# need to figure out a way to generate these as distinct chains, so that the result of the to_datarame method is more compact

posterior_predictive = my_model.posterior_predictive(
n_datapoints=len(data_observed_hospital_admissions) + n_forecast_points
n_datapoints=len(data_observed_disease_hospital_admissions)
+ n_forecast_points
)

idata = az.from_numpyro(
Expand Down
101 changes: 101 additions & 0 deletions nssp_demo/other_ed_admissions_forecast.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
library(tidyverse)
library(fs)
library(fable)
library(jsonlite)
library(glue)
library(argparser)
library(arrow)

p <- arg_parser("Forecast other (non-target-disease) ED admissions") %>%
add_argument(p, "--model_dir",
help = "Directory containing the model data",
required = TRUE
) %>%
add_argument("--n_forecast_days",
help = "Number of days to forecast",
default = 28L
) %>%
add_argument("--n_samples",
help = "Number of samples to generate",
default = 2000L
)

argv <- parse_args(p)
model_dir <- path(argv$model_dir)
n_forecast_days <- argv$n_forecast_days
n_samples <- arv$n_samples

disease_name_nssp_map <- c(
"covid-19" = "COVID-19/Omicron",
"influenza" = "Influenza"
)

base_dir <- path_dir(model_dir)

disease_name_raw <- base_dir %>%
path_file() %>%
str_extract("^.+(?=_r_)")

disease_name_nssp <- unname(disease_name_nssp_map[disease_name_raw])

fit_and_forecast <- function(other_data,
n_forecast_days = 28,
n_samples = 2000) {
forecast_horizon <- glue("{n_forecast_days} days")

fit <-
other_data %>%
filter(data_type == "train") %>%
model(
comb_model = combination_ensemble(
ETS(log(ED_admissions) ~ trend(method = c("N", "M", "A"))),
ARIMA(log(ED_admissions))
),
arima = ARIMA(log(ED_admissions)),
ets = ETS(log(ED_admissions) ~ trend(method = c("N", "M", "A")))
)

forecast_samples <- fit |>
generate(h = forecast_horizon, times = n_samples) |>
as_tibble() %>%
mutate(ED_admissions = .sim, .draw = as.integer(.rep)) |>
filter(.model == "comb_model") %>%
select(date, .draw, other_ED_admissions = ED_admissions)

forecast_samples
}

main <- function(model_dir, n_forecast_days = 28, n_samples = 2000) {
# to do: do this with json data that has dates
data_path <- path(model_dir, "data", ext = "csv")

other_data <- read_csv(data_path) %>%
mutate(disease = if_else(disease == disease_name_nssp,
"Disease", disease
)) %>%
pivot_wider(names_from = disease, values_from = ED_admissions) %>%
mutate(Other = Total - Disease) %>%
select(date, ED_admissions = Other, data_type) %>%
as_tsibble(index = date)

forecast_samples <- fit_and_forecast(other_data, n_forecast_days, n_samples)

save_path <- path(model_dir, "other_ed_admissions_forecast", ext = "parquet")
write_parquet(forecast_samples, save_path)
}


main(model_dir, n_forecast_days, n_samples)
# File will end here once command line version is working
# Temp code to run for all states while command line version doesn't work
# Command line version is dependent on https://github.com/rstudio/renv/pull/2018

base_dir <- path(
"nssp_demo/private_data/influenza_r_2024-10-21_f_2024-07-16_t_2024-10-13"
)

dir_ls(base_dir, type = "dir") %>%
walk(.f = function(model_dir) {
print(path_file(model_dir))
main(model_dir, n_forecast_days = 28, n_samples = 2000)
})
Loading

0 comments on commit ce71163

Please sign in to comment.