-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b2b4aa7
commit d4e1984
Showing
2 changed files
with
608 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,386 @@ | ||
--- | ||
title: "Replicating Hospital Only Model from `cdcgov/wastewater-informed-covid-forecasting`" | ||
format: gfm | ||
engine: jupyter | ||
--- | ||
|
||
```{python} | ||
# | label: setup | ||
import json | ||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
import numpyro | ||
import numpyro.distributions as dist | ||
import numpyro.distributions.transforms as transforms | ||
from pyrenew.deterministic import DeterministicVariable | ||
from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable | ||
from pyrenew.model import hosp_only_ww_model | ||
numpyro.set_host_device_count(1) | ||
# model crashes if run in parallel | ||
# see https://github.com/pyro-ppl/numpyro/issues/1836 | ||
``` | ||
|
||
## Background | ||
|
||
This tutorial provides a demonstration of our reimplementation of "Model 2" from the `wastewater-informed-covid-forecasting` project. | ||
The model is described [here](https://github.com/CDCgov/wastewater-informed-covid-forecasting/blob/prod/model_definition.md). | ||
Stan code for the model is [here](https://github.com/CDCgov/wastewater-informed-covid-forecasting/blob/prod/cfaforecastrenewalww/inst/stan/renewal_ww_hosp_site_level_inf_dynamics.stan). | ||
|
||
The model we provide is designed to be fully-compatible with the stan_data generated in the that project. | ||
We provide the stan data used in the `toy_data_vignette` [vignette](https://github.com/CDCgov/wastewater-informed-covid-forecasting/blob/prod/cfaforecastrenewalww/vignettes/toy_data_vignette.Rmd) in the `wastewater-informed-covid-forecasting` project. | ||
The data is available in `scratch/stan_data_hosp_only.json`. | ||
This data was generated by running `scratch/save_from_vignette.R` after running all the cells in the vignette. | ||
This script also saves the posterior samples from the model for comparison to our own model. | ||
|
||
## Load Data and create Priors | ||
|
||
We begin by loading the stan_data and converting it to priors used in our model. | ||
```{python} | ||
# | label: Load data and create priors | ||
# | code-fold: true | ||
def convert_to_logmean_log_sd(mean, sd): | ||
logmean = np.log( | ||
np.power(mean, 2) / np.sqrt(np.power(sd, 2) + np.power(mean, 2)) | ||
) | ||
logsd = np.sqrt(np.log(1 + (np.power(sd, 2) / np.power(mean, 2)))) | ||
return logmean, logsd | ||
# Load the JSON file | ||
import builtins | ||
with builtins.open( | ||
"../../../scratch/stan_data_hosp_only.json", | ||
"r", | ||
) as file: | ||
stan_data = json.load(file) | ||
i0_over_n_prior_a = stan_data["i0_over_n_prior_a"][0] | ||
i0_over_n_prior_b = stan_data["i0_over_n_prior_b"][0] | ||
i0_over_n_rv = DistributionalRV( | ||
"i0_over_n_rv", dist.Beta(i0_over_n_prior_a, i0_over_n_prior_b) | ||
) | ||
initial_growth_prior_mean = stan_data["initial_growth_prior_mean"][0] | ||
initial_growth_prior_sd = stan_data["initial_growth_prior_sd"][0] | ||
initialization_rate_rv = DistributionalRV( | ||
"rate", | ||
dist.TruncatedNormal( | ||
loc=initial_growth_prior_mean, | ||
scale=initial_growth_prior_sd, | ||
low=-1, | ||
high=1, | ||
), | ||
) | ||
# could reasonably switch to non-Truncated | ||
r_prior_mean = stan_data["r_prior_mean"][0] | ||
r_prior_sd = stan_data["r_prior_sd"][0] | ||
r_logmean, r_logsd = convert_to_logmean_log_sd(r_prior_mean, r_prior_sd) | ||
log_r_mu_intercept_rv = DistributionalRV( | ||
"log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd) | ||
) | ||
eta_sd_sd = stan_data["eta_sd_sd"][0] | ||
eta_sd_rv = DistributionalRV( | ||
"eta_sd", dist.TruncatedNormal(0, eta_sd_sd, low=0) | ||
) | ||
autoreg_rt_a = stan_data["autoreg_rt_a"][0] | ||
autoreg_rt_b = stan_data["autoreg_rt_b"][0] | ||
autoreg_rt_rv = DistributionalRV( | ||
"autoreg_rt", dist.Beta(autoreg_rt_a, autoreg_rt_b) | ||
) | ||
generation_interval_pmf_rv = DeterministicVariable( | ||
"generation_interval_pmf", jnp.array(stan_data["generation_interval"]) | ||
) | ||
infection_feedback_pmf_rv = DeterministicVariable( | ||
"infection_feedback_pmf", jnp.array(stan_data["infection_feedback_pmf"]) | ||
) | ||
inf_feedback_prior_logmean = stan_data["inf_feedback_prior_logmean"][0] | ||
inf_feedback_prior_logsd = stan_data["inf_feedback_prior_logsd"][0] | ||
inf_feedback_strength_rv = TransformedRandomVariable( | ||
"inf_feedback", | ||
DistributionalRV( | ||
"inf_feedback_raw", | ||
dist.LogNormal(inf_feedback_prior_logmean, inf_feedback_prior_logsd), | ||
), | ||
transforms=transforms.AffineTransform(loc=0, scale=-1), | ||
) | ||
# Could be reparameterized? | ||
p_hosp_prior_mean = stan_data["p_hosp_prior_mean"][0] | ||
p_hosp_sd_logit = stan_data["p_hosp_sd_logit"][0] | ||
p_hosp_mean_rv = DistributionalRV( | ||
"p_hosp_mean", | ||
dist.Normal(transforms.logit(p_hosp_prior_mean), p_hosp_sd_logit), | ||
) # logit scale | ||
p_hosp_w_sd_sd = stan_data["p_hosp_w_sd_sd"][0] | ||
p_hosp_w_sd_rv = DistributionalRV( | ||
"p_hosp_w_sd_sd", dist.TruncatedNormal(0, p_hosp_w_sd_sd, low=0) | ||
) | ||
autoreg_p_hosp_a = stan_data["autoreg_p_hosp_a"][0] | ||
autoreg_p_hosp_b = stan_data["autoreg_p_hosp_b"][0] | ||
autoreg_p_hosp_rv = DistributionalRV( | ||
"autoreg_p_hosp", dist.Beta(autoreg_p_hosp_a, autoreg_p_hosp_b) | ||
) | ||
# hosp_wday_effect ~ normal(effect_mean, wday_effect_prior_sd); | ||
# wday_effect_prior_mean = stan_data["wday_effect_prior_mean"][0] | ||
# wday_effect_prior_sd = stan_data["wday_effect_prior_sd"][0] | ||
# Instead of the above, use a Dirichlet prior (see https://github.com/CDCgov/ww-inference-model/issues/42) | ||
hosp_wday_effect_rv = TransformedRandomVariable( | ||
"hosp_wday_effect", | ||
DistributionalRV( | ||
"hosp_wday_effect_raw", dist.Dirichlet(concentration=jnp.ones(7)) | ||
), | ||
transforms.AffineTransform(loc=0, scale=7), | ||
) | ||
inf_to_hosp_rv = DeterministicVariable( | ||
"inf_to_hosp", jnp.array(stan_data["inf_to_hosp"]) | ||
) | ||
inv_sqrt_phi_prior_mean = stan_data["inv_sqrt_phi_prior_mean"][0] | ||
inv_sqrt_phi_prior_sd = stan_data["inv_sqrt_phi_prior_sd"][0] | ||
phi_rv = TransformedRandomVariable( | ||
"phi", | ||
DistributionalRV( | ||
"inv_sqrt_phi", | ||
dist.TruncatedNormal( | ||
loc=inv_sqrt_phi_prior_mean, | ||
scale=inv_sqrt_phi_prior_sd, | ||
low=1 / jnp.sqrt(5000), | ||
), | ||
), | ||
transforms=transforms.PowerTransform(-2), | ||
) | ||
uot = stan_data["uot"][0] | ||
state_pop = stan_data["state_pop"][0] | ||
data_observed_hospital_admissions = jnp.array(stan_data["hosp"]) | ||
``` | ||
|
||
# Simulate from the model | ||
|
||
Next, we define the model: | ||
|
||
```{python} | ||
# | label: define the model | ||
my_model = hosp_only_ww_model( | ||
state_pop=state_pop, | ||
i0_over_n_rv=i0_over_n_rv, | ||
initialization_rate_rv=initialization_rate_rv, | ||
log_r_mu_intercept_rv=log_r_mu_intercept_rv, | ||
autoreg_rt_rv=autoreg_rt_rv, # ar process | ||
eta_sd_rv=eta_sd_rv, # sd of random walk for ar process, | ||
generation_interval_pmf_rv=generation_interval_pmf_rv, | ||
infection_feedback_pmf_rv=infection_feedback_pmf_rv, | ||
infection_feedback_strength_rv=inf_feedback_strength_rv, | ||
p_hosp_mean_rv=p_hosp_mean_rv, | ||
p_hosp_w_sd_rv=p_hosp_w_sd_rv, | ||
autoreg_p_hosp_rv=autoreg_p_hosp_rv, | ||
hosp_wday_effect_rv=hosp_wday_effect_rv, | ||
phi_rv=phi_rv, | ||
inf_to_hosp_rv=inf_to_hosp_rv, | ||
# n_initialization_points=uot, | ||
n_initialization_points=len(jnp.array(stan_data["inf_to_hosp"])), | ||
# i0_t_offset=-50, # to match stan model | ||
i0_t_offset=0, # a better way of parameterizing | ||
) | ||
``` | ||
|
||
|
||
We check that we can simulate from the prior predictive | ||
```{python} | ||
# | label: prior predictive | ||
# | eval: false | ||
# for some reason the posterior inference crashes if we do the prior predictive first | ||
prior_predictive = my_model.prior_predictive( | ||
n_datapoints=len(data_observed_hospital_admissions), | ||
numpyro_predictive_args={"num_samples": 200}, | ||
) | ||
``` | ||
|
||
# Fit the model | ||
|
||
Now we can fit the model to the observed data: | ||
```{python} | ||
# | label: fit the model | ||
my_model.run( | ||
num_warmup=500, | ||
num_samples=500, | ||
rng_key=jax.random.key(200), | ||
data_observed_hospital_admissions=data_observed_hospital_admissions, | ||
mcmc_args=dict(num_chains=4), | ||
) | ||
``` | ||
|
||
Check the posterior predictive: | ||
|
||
```{python} | ||
# | label: posterior predictive | ||
posterior_predictive = my_model.posterior_predictive( | ||
n_datapoints=len(data_observed_hospital_admissions) | ||
) | ||
``` | ||
|
||
Forecasting is broken (dependent on https://github.com/CDCgov/multisignal-epi-inference/issues/328) | ||
|
||
```{python} | ||
# | label: posterior forecast | ||
# | eval: false | ||
my_model.posterior_predictive( | ||
n_datapoints=len(data_observed_hospital_admissions) + 2 | ||
) | ||
``` | ||
|
||
|
||
## Prepare for plotting | ||
```{python} | ||
import arviz as az | ||
idata = az.from_numpyro( | ||
my_model.mcmc, posterior_predictive=posterior_predictive | ||
) | ||
``` | ||
|
||
```{python} | ||
def compute_eti(dataset, eti_prob): | ||
eti_bdry = dataset.quantile( | ||
((1 - eti_prob) / 2, 1 / 2 + eti_prob / 2), dim=("chain", "draw") | ||
) | ||
return eti_bdry.values.T | ||
``` | ||
|
||
|
||
```{python} | ||
import matplotlib.pyplot as plt | ||
def plot_posterior(name, predictive=False): | ||
if predictive: | ||
posterior_object = idata.posterior_predictive | ||
else: | ||
posterior_object = idata.posterior | ||
x_data = posterior_object[f"{name}_dim_0"] | ||
y_data = posterior_object[name] | ||
fig, axes = plt.subplots(figsize=(6, 5)) | ||
az.plot_hdi( | ||
x_data, | ||
hdi_data=compute_eti(y_data, 0.9), | ||
color="C0", | ||
smooth=False, | ||
fill_kwargs={"alpha": 0.3}, | ||
ax=axes, | ||
) | ||
az.plot_hdi( | ||
x_data, | ||
hdi_data=compute_eti(y_data, 0.5), | ||
color="C0", | ||
smooth=False, | ||
fill_kwargs={"alpha": 0.6}, | ||
ax=axes, | ||
) | ||
# Add median of the posterior to the figure | ||
median_ts = y_data.median(dim=["chain", "draw"]) | ||
plt.plot( | ||
x_data, | ||
median_ts, | ||
color="C0", | ||
label="Median", | ||
) | ||
axes.legend() | ||
axes.set_title(name, fontsize=10) | ||
axes.set_xlabel("Time", fontsize=10) | ||
axes.set_ylabel(name, fontsize=10) | ||
return fig | ||
``` | ||
|
||
## Plot all posteriors | ||
# Do we know why some univariate sites have a dimension and others do not? | ||
```{python} | ||
for key in list(idata.posterior.keys()): | ||
try: | ||
fig = plot_posterior(key) | ||
fig.show() | ||
except Exception as e: | ||
print(f"An error occurred while plotting {key}: {e}") | ||
``` | ||
|
||
## Posterior predictive hospital admissions | ||
```{python} | ||
import matplotlib.pyplot as plt | ||
x_data = idata.posterior_predictive["observed_hospital_admissions_dim_0"] + uot | ||
y_data = idata.posterior_predictive["observed_hospital_admissions"] | ||
fig, axes = plt.subplots(figsize=(6, 5)) | ||
az.plot_hdi( | ||
x_data, | ||
hdi_data=compute_eti(y_data, 0.9), | ||
color="C0", | ||
smooth=False, | ||
fill_kwargs={"alpha": 0.3}, | ||
ax=axes, | ||
) | ||
az.plot_hdi( | ||
x_data, | ||
hdi_data=compute_eti(y_data, 0.5), | ||
color="C0", | ||
smooth=False, | ||
fill_kwargs={"alpha": 0.6}, | ||
ax=axes, | ||
) | ||
# Add median of the posterior to the figure | ||
median_ts = y_data.median(dim=["chain", "draw"]) | ||
plt.plot( | ||
x_data, | ||
median_ts, | ||
color="C0", | ||
label="Median", | ||
) | ||
plt.scatter( | ||
idata.observed_data["observed_hospital_admissions_dim_0"] + uot, | ||
idata.observed_data["observed_hospital_admissions"], | ||
color="black", | ||
) | ||
axes.legend() | ||
axes.set_title("Posterior Predictive Admissions", fontsize=10) | ||
axes.set_xlabel("Time", fontsize=10) | ||
axes.set_ylabel("Hospital Admissions", fontsize=10) | ||
plt.show() | ||
``` | ||
|
||
|
||
|
||
```{python} | ||
idata.posterior["log_r_mu_intercept_rv"] | ||
az.summary( | ||
idata, | ||
var_names=["log_rt", "periodic_diff_sd", "autoreg_rt_det", "rtu"], | ||
stat_focus="median", | ||
) | ||
``` | ||
Why is log_rt 2 dimensional? |
Oops, something went wrong.