Skip to content

Commit

Permalink
add files from pyrenew PR
Browse files Browse the repository at this point in the history
  • Loading branch information
damonbayer committed Aug 22, 2024
1 parent b2b4aa7 commit d4e1984
Show file tree
Hide file tree
Showing 2 changed files with 608 additions and 0 deletions.
386 changes: 386 additions & 0 deletions notebooks/hosp_only_ww_model.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,386 @@
---
title: "Replicating Hospital Only Model from `cdcgov/wastewater-informed-covid-forecasting`"
format: gfm
engine: jupyter
---

```{python}
# | label: setup
import json
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import numpyro.distributions.transforms as transforms
from pyrenew.deterministic import DeterministicVariable
from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable
from pyrenew.model import hosp_only_ww_model
# Restrict NumPyro/XLA to a single host (CPU) device.
# The model crashes if run in parallel;
# see https://github.com/pyro-ppl/numpyro/issues/1836
numpyro.set_host_device_count(1)
```

## Background

This tutorial provides a demonstration of our reimplementation of "Model 2" from the `wastewater-informed-covid-forecasting` project.
The model is described [here](https://github.com/CDCgov/wastewater-informed-covid-forecasting/blob/prod/model_definition.md).
Stan code for the model is [here](https://github.com/CDCgov/wastewater-informed-covid-forecasting/blob/prod/cfaforecastrenewalww/inst/stan/renewal_ww_hosp_site_level_inf_dynamics.stan).

The model we provide is designed to be fully compatible with the `stan_data` generated in that project.
We provide the stan data used in the `toy_data_vignette` [vignette](https://github.com/CDCgov/wastewater-informed-covid-forecasting/blob/prod/cfaforecastrenewalww/vignettes/toy_data_vignette.Rmd) in the `wastewater-informed-covid-forecasting` project.
The data is available in `scratch/stan_data_hosp_only.json`.
This data was generated by running `scratch/save_from_vignette.R` after running all the cells in the vignette.
This script also saves the posterior samples from the model for comparison to our own model.

## Load Data and create Priors

We begin by loading the stan_data and converting it to priors used in our model.
```{python}
# | label: Load data and create priors
# | code-fold: true
def convert_to_logmean_log_sd(mean, sd):
    """Convert a natural-scale mean and sd to log-normal parameters.

    Given the mean and standard deviation of a log-normally distributed
    quantity, return the mean and standard deviation of its logarithm.

    Parameters
    ----------
    mean : float or array-like
        Mean on the natural scale. Must be non-zero (a log-normal mean is
        positive).
    sd : float or array-like
        Standard deviation on the natural scale.

    Returns
    -------
    tuple
        ``(logmean, logsd)``: mean and standard deviation on the log scale.
    """
    mean_sq = np.square(mean)
    sd_sq = np.square(sd)
    logmean = np.log(mean_sq / np.sqrt(sd_sq + mean_sq))
    # log1p keeps precision when sd << mean; log(1 + x) would round the
    # argument to exactly 1 and return a degenerate logsd of 0.
    logsd = np.sqrt(np.log1p(sd_sq / mean_sq))
    return logmean, logsd
# Load the JSON file
# The path is relative to the directory the notebook is executed from.
import builtins

with builtins.open(
    "../../../scratch/stan_data_hosp_only.json",
    "r",
    encoding="utf-8",  # JSON is UTF-8; do not rely on the locale default
) as file:
    stan_data = json.load(file)
# Priors for the infection process, built from hyperparameters in stan_data.
# Scalar hyperparameters are read with `[0]` (presumably stored as
# length-1 lists in the JSON -- TODO confirm).

# Beta prior for i0_over_n.
i0_over_n_prior_a = stan_data["i0_over_n_prior_a"][0]
i0_over_n_prior_b = stan_data["i0_over_n_prior_b"][0]
i0_over_n_rv = DistributionalRV(
"i0_over_n_rv", dist.Beta(i0_over_n_prior_a, i0_over_n_prior_b)
)
# Normal prior truncated to [-1, 1] for the initialization rate,
# registered under the name "rate".
initial_growth_prior_mean = stan_data["initial_growth_prior_mean"][0]
initial_growth_prior_sd = stan_data["initial_growth_prior_sd"][0]
initialization_rate_rv = DistributionalRV(
"rate",
dist.TruncatedNormal(
loc=initial_growth_prior_mean,
scale=initial_growth_prior_sd,
low=-1,
high=1,
),
)
# could reasonably switch to non-Truncated
# The r prior is given as a natural-scale mean/sd; convert it to log-scale
# parameters and place a Normal prior on the log-scale intercept.
r_prior_mean = stan_data["r_prior_mean"][0]
r_prior_sd = stan_data["r_prior_sd"][0]
r_logmean, r_logsd = convert_to_logmean_log_sd(r_prior_mean, r_prior_sd)
log_r_mu_intercept_rv = DistributionalRV(
"log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd)
)
# Normal(0, eta_sd_sd) truncated below at 0 for eta_sd.
eta_sd_sd = stan_data["eta_sd_sd"][0]
eta_sd_rv = DistributionalRV(
"eta_sd", dist.TruncatedNormal(0, eta_sd_sd, low=0)
)
# Beta prior for autoreg_rt.
autoreg_rt_a = stan_data["autoreg_rt_a"][0]
autoreg_rt_b = stan_data["autoreg_rt_b"][0]
autoreg_rt_rv = DistributionalRV(
"autoreg_rt", dist.Beta(autoreg_rt_a, autoreg_rt_b)
)
# The two PMFs are taken from stan_data as-is and wrapped as
# DeterministicVariable (no prior is placed on them here).
generation_interval_pmf_rv = DeterministicVariable(
"generation_interval_pmf", jnp.array(stan_data["generation_interval"])
)
infection_feedback_pmf_rv = DeterministicVariable(
"infection_feedback_pmf", jnp.array(stan_data["infection_feedback_pmf"])
)
# "inf_feedback" is a LogNormal draw ("inf_feedback_raw") multiplied by -1
# (AffineTransform with loc=0, scale=-1), so its value is negative.
inf_feedback_prior_logmean = stan_data["inf_feedback_prior_logmean"][0]
inf_feedback_prior_logsd = stan_data["inf_feedback_prior_logsd"][0]
inf_feedback_strength_rv = TransformedRandomVariable(
"inf_feedback",
DistributionalRV(
"inf_feedback_raw",
dist.LogNormal(inf_feedback_prior_logmean, inf_feedback_prior_logsd),
),
transforms=transforms.AffineTransform(loc=0, scale=-1),
)
# Could be reparameterized?
# Priors for the hospitalization process.

# Normal prior centred at logit(p_hosp_prior_mean) with sd p_hosp_sd_logit.
p_hosp_prior_mean = stan_data["p_hosp_prior_mean"][0]
p_hosp_sd_logit = stan_data["p_hosp_sd_logit"][0]
p_hosp_mean_rv = DistributionalRV(
"p_hosp_mean",
dist.Normal(transforms.logit(p_hosp_prior_mean), p_hosp_sd_logit),
) # logit scale
# Normal(0, p_hosp_w_sd_sd) truncated below at 0.
# NOTE(review): the registered name "p_hosp_w_sd_sd" has one more "_sd" than
# the variable name p_hosp_w_sd_rv -- confirm this is intended.
p_hosp_w_sd_sd = stan_data["p_hosp_w_sd_sd"][0]
p_hosp_w_sd_rv = DistributionalRV(
"p_hosp_w_sd_sd", dist.TruncatedNormal(0, p_hosp_w_sd_sd, low=0)
)
# Beta prior for autoreg_p_hosp.
autoreg_p_hosp_a = stan_data["autoreg_p_hosp_a"][0]
autoreg_p_hosp_b = stan_data["autoreg_p_hosp_b"][0]
autoreg_p_hosp_rv = DistributionalRV(
"autoreg_p_hosp", dist.Beta(autoreg_p_hosp_a, autoreg_p_hosp_b)
)
# hosp_wday_effect ~ normal(effect_mean, wday_effect_prior_sd);
# wday_effect_prior_mean = stan_data["wday_effect_prior_mean"][0]
# wday_effect_prior_sd = stan_data["wday_effect_prior_sd"][0]
# Instead of the above, use a Dirichlet prior (see https://github.com/CDCgov/ww-inference-model/issues/42)
# The 7-component Dirichlet draw is scaled by 7, so the effects sum to 7.
hosp_wday_effect_rv = TransformedRandomVariable(
"hosp_wday_effect",
DistributionalRV(
"hosp_wday_effect_raw", dist.Dirichlet(concentration=jnp.ones(7))
),
transforms.AffineTransform(loc=0, scale=7),
)
# Taken from stan_data as-is and wrapped as DeterministicVariable.
inf_to_hosp_rv = DeterministicVariable(
"inf_to_hosp", jnp.array(stan_data["inf_to_hosp"])
)
# phi is defined through a prior on inv_sqrt_phi: a Normal truncated below at
# 1 / sqrt(5000), mapped through PowerTransform(-2), i.e.
# phi = inv_sqrt_phi ** -2. The lower truncation therefore caps phi at 5000.
inv_sqrt_phi_prior_mean = stan_data["inv_sqrt_phi_prior_mean"][0]
inv_sqrt_phi_prior_sd = stan_data["inv_sqrt_phi_prior_sd"][0]
phi_rv = TransformedRandomVariable(
"phi",
DistributionalRV(
"inv_sqrt_phi",
dist.TruncatedNormal(
loc=inv_sqrt_phi_prior_mean,
scale=inv_sqrt_phi_prior_sd,
low=1 / jnp.sqrt(5000),
),
),
transforms=transforms.PowerTransform(-2),
)
# Remaining scalars and the observed series used by later cells.
# uot is added to the time index when plotting observed admissions.
uot = stan_data["uot"][0]
state_pop = stan_data["state_pop"][0]
data_observed_hospital_admissions = jnp.array(stan_data["hosp"])
```

# Simulate from the model

Next, we define the model:

```{python}
# | label: define the model
# Assemble the model from the random variables and constants defined above.
# Two arguments deliberately differ from the commented-out alternatives:
# n_initialization_points uses the length of the inf_to_hosp PMF instead of
# uot, and i0_t_offset is 0 instead of -50.
my_model = hosp_only_ww_model(
state_pop=state_pop,
i0_over_n_rv=i0_over_n_rv,
initialization_rate_rv=initialization_rate_rv,
log_r_mu_intercept_rv=log_r_mu_intercept_rv,
autoreg_rt_rv=autoreg_rt_rv, # ar process
eta_sd_rv=eta_sd_rv, # sd of random walk for ar process,
generation_interval_pmf_rv=generation_interval_pmf_rv,
infection_feedback_pmf_rv=infection_feedback_pmf_rv,
infection_feedback_strength_rv=inf_feedback_strength_rv,
p_hosp_mean_rv=p_hosp_mean_rv,
p_hosp_w_sd_rv=p_hosp_w_sd_rv,
autoreg_p_hosp_rv=autoreg_p_hosp_rv,
hosp_wday_effect_rv=hosp_wday_effect_rv,
phi_rv=phi_rv,
inf_to_hosp_rv=inf_to_hosp_rv,
# n_initialization_points=uot,
n_initialization_points=len(jnp.array(stan_data["inf_to_hosp"])),
# i0_t_offset=-50, # to match stan model
i0_t_offset=0, # a better way of parameterizing
)
```


We check that we can simulate from the prior predictive distribution:
```{python}
# | label: prior predictive
# | eval: false
# This cell is not evaluated (`eval: false`):
# for some reason the posterior inference crashes if we do the prior predictive first
# Requests 200 prior predictive samples, with as many data points as there
# are observed hospital admissions.
prior_predictive = my_model.prior_predictive(
n_datapoints=len(data_observed_hospital_admissions),
numpyro_predictive_args={"num_samples": 200},
)
```

# Fit the model

Now we can fit the model to the observed data:
```{python}
# | label: fit the model
# Run MCMC with 500 warmup and 500 sampling iterations, a fixed RNG key for
# reproducibility, and 4 chains.
# NOTE(review): the setup cell calls numpyro.set_host_device_count(1), so the
# 4 chains presumably cannot run in parallel -- confirm this is intended.
my_model.run(
num_warmup=500,
num_samples=500,
rng_key=jax.random.key(200),
data_observed_hospital_admissions=data_observed_hospital_admissions,
mcmc_args=dict(num_chains=4),
)
```

Check the posterior predictive:

```{python}
# | label: posterior predictive
# Posterior predictive draws for the same number of time points as the
# observed series; the result is passed to ArviZ in a later cell.
posterior_predictive = my_model.posterior_predictive(
n_datapoints=len(data_observed_hospital_admissions)
)
```

Forecasting is currently broken (blocked on https://github.com/CDCgov/multisignal-epi-inference/issues/328), so the following cell is not evaluated:

```{python}
# | label: posterior forecast
# | eval: false
# Not evaluated (`eval: false`). Requests two data points beyond the observed
# series; the result is not assigned to a variable.
my_model.posterior_predictive(
n_datapoints=len(data_observed_hospital_admissions) + 2
)
```


## Prepare for plotting
```{python}
import arviz as az
# Bundle the fitted MCMC object and the posterior predictive draws into one
# ArviZ object; the plotting cells below read its `posterior`,
# `posterior_predictive` and `observed_data` groups.
idata = az.from_numpyro(
my_model.mcmc, posterior_predictive=posterior_predictive
)
```

```{python}
def compute_eti(dataset, eti_prob):
    """Return the bounds of a central (equal-tailed) interval.

    Parameters
    ----------
    dataset
        Object exposing ``quantile(q, dim=...)`` with ``"chain"`` and
        ``"draw"`` dimensions (presumably an xarray ``DataArray`` of MCMC
        draws -- confirm against callers).
    eti_prob : float
        Probability mass of the interval, e.g. ``0.9``.

    Returns
    -------
    The transposed ``.values`` of the quantile result, evaluated at levels
    ``(1 - eti_prob) / 2`` and ``1 / 2 + eti_prob / 2``.
    """
    lower_level = (1 - eti_prob) / 2
    upper_level = 1 / 2 + eti_prob / 2
    bounds = dataset.quantile(
        (lower_level, upper_level), dim=("chain", "draw")
    )
    return bounds.values.T
```


```{python}
import matplotlib.pyplot as plt
def plot_posterior(name, predictive=False):
    """Plot the median and the 90% / 50% equal-tailed bands of a variable.

    Parameters
    ----------
    name : str
        Variable to plot. It is read from the global ``idata`` together with
        its ``f"{name}_dim_0"`` entry, which supplies the x values.
    predictive : bool, optional
        If true, read from ``idata.posterior_predictive``; otherwise (the
        default) read from ``idata.posterior``.

    Returns
    -------
    The figure created by ``plt.subplots``.
    """
    group = idata.posterior_predictive if predictive else idata.posterior
    times = group[f"{name}_dim_0"]
    draws = group[name]

    fig, axes = plt.subplots(figsize=(6, 5))

    # Wide band first (lighter), narrow band second (darker, drawn on top).
    for eti_prob, alpha in ((0.9, 0.3), (0.5, 0.6)):
        az.plot_hdi(
            times,
            hdi_data=compute_eti(draws, eti_prob),
            color="C0",
            smooth=False,
            fill_kwargs={"alpha": alpha},
            ax=axes,
        )

    # Add median of the posterior to the figure
    plt.plot(
        times,
        draws.median(dim=["chain", "draw"]),
        color="C0",
        label="Median",
    )

    axes.legend()
    axes.set_title(name, fontsize=10)
    axes.set_xlabel("Time", fontsize=10)
    axes.set_ylabel(name, fontsize=10)
    return fig
```

## Plot all posteriors
Open question: do we know why some univariate sites have a dimension and others do not?
```{python}
# Attempt a plot for every posterior variable. plot_posterior needs a
# "<name>_dim_0" entry, so variables without one (or that fail to plot for
# any other reason) are reported and skipped instead of stopping the loop.
for key in idata.posterior.keys():
    try:
        fig = plot_posterior(key)
        fig.show()
    except Exception as err:
        print(f"An error occurred while plotting {key}: {err}")
```

## Posterior predictive hospital admissions
```{python}
import matplotlib.pyplot as plt
# Plot posterior predictive hospital admissions against the observed data.
# The time index is shifted by uot for both the predictive draws and the
# observations.
x_data = idata.posterior_predictive["observed_hospital_admissions_dim_0"] + uot
y_data = idata.posterior_predictive["observed_hospital_admissions"]
fig, axes = plt.subplots(figsize=(6, 5))
# 90% equal-tailed band (lighter fill).
az.plot_hdi(
x_data,
hdi_data=compute_eti(y_data, 0.9),
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.3},
ax=axes,
)
# 50% equal-tailed band (darker fill).
az.plot_hdi(
x_data,
hdi_data=compute_eti(y_data, 0.5),
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.6},
ax=axes,
)
# Add median of the posterior to the figure
median_ts = y_data.median(dim=["chain", "draw"])
plt.plot(
x_data,
median_ts,
color="C0",
label="Median",
)
# Observed admissions as black points (no label is passed to this call).
plt.scatter(
idata.observed_data["observed_hospital_admissions_dim_0"] + uot,
idata.observed_data["observed_hospital_admissions"],
color="black",
)
axes.legend()
axes.set_title("Posterior Predictive Admissions", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Hospital Admissions", fontsize=10)
plt.show()
```



```{python}
# NOTE(review): this bare lookup is not the last expression in the cell, so
# its value is presumably not displayed -- confirm whether it is still needed.
idata.posterior["log_r_mu_intercept_rv"]
# Median-focused summary of selected posterior variables.
az.summary(
idata,
var_names=["log_rt", "periodic_diff_sd", "autoreg_rt_det", "rtu"],
stat_focus="median",
)
```
Open question: why is `log_rt` two-dimensional?
Loading

0 comments on commit d4e1984

Please sign in to comment.