Skip to content

Commit

Permalink
Update pyrenew_hew_model sample signature for PyRenew-HE (#265)
Browse files Browse the repository at this point in the history
* Create README.md

* move/rename build_model.py

* delete unused files

* move and rename

* change import references

* pre-commit

* adjust imports

* pre-commit

* remove extra forecast_state

* undo lots of changes

* restore readme

* remove more inits

* remove another init

* correct import

* pre-commit

* rename n_datapoints

* recommit n_datapoints change

* allow observing hospital admissions

* update predictive and fix names

---------

Co-authored-by: Dylan H. Morris <[email protected]>
  • Loading branch information
damonbayer and dylanhmorris authored Dec 30, 2024
1 parent 74a2a3f commit d5c0199
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 25 deletions.
4 changes: 2 additions & 2 deletions demos/hosp_only_ww_model/pyrenew_hew_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ We check that we can simulate from the prior predictive
n_forecast_days = 35
prior_predictive = my_pyrenew_hew_model.prior_predictive(
n_datapoints=len(data_observed_disease_ed_visits) + n_forecast_days,
n_observed_disease_ed_visits_datapoints=len(data_observed_disease_ed_visits) + n_forecast_days,
numpyro_predictive_args={"num_samples": 200},
)
```
Expand All @@ -75,7 +75,7 @@ Create the posterior predictive and forecast:
```{python}
# | label: posterior predictive
posterior_predictive = my_pyrenew_hew_model.posterior_predictive(
n_datapoints=len(data_observed_disease_ed_visits) + n_forecast_days
n_observed_disease_ed_visits_datapoints=len(data_observed_disease_ed_visits) + n_forecast_days
)
```

Expand Down
12 changes: 6 additions & 6 deletions demos/ww_model/ww_model_demo.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ max_ww_sampled_days = max(stan_data["ww_sampled_times"])
lab_site_to_subpop_map = jnp.array(stan_data["lab_site_to_subpop_map"]) -1 #vector mapping the subpops to lab-site combos
data_observed_log_conc = jnp.array(stan_data["log_conc"])
data_observed_hospital_admissions = jnp.array(stan_data["hosp"])
data_observed_disease_hospital_admissions = jnp.array(stan_data["hosp"])
```

```{python}
Expand Down Expand Up @@ -327,7 +327,7 @@ Check that we can simulate from the prior predictive
n_forecast_days = 35
prior_predictive = my_model.prior_predictive(
n_datapoints= max(len(data_observed_hospital_admissions), max_ww_sampled_days) + n_forecast_days,
n_observed_disease_ed_visits_datapoints= max(len(data_observed_disease_hospital_admissions), max_ww_sampled_days) + n_forecast_days,
numpyro_predictive_args={"num_samples": 100},
)
```
Expand All @@ -339,7 +339,7 @@ my_model.run(
num_warmup=750,
num_samples=500,
rng_key=jax.random.key(223),
data_observed_hospital_admissions=data_observed_hospital_admissions,
data_observed_disease_hospital_admissions=data_observed_disease_hospital_admissions,
data_observed_log_conc=data_observed_log_conc,
mcmc_args=dict(num_chains=4)
)
Expand All @@ -349,7 +349,7 @@ Simulate the posterior predictive distribution
```{python}
# | label: posterior predictive
posterior_predictive = my_model.posterior_predictive(
n_datapoints= max(len(data_observed_hospital_admissions), max_ww_sampled_days) + n_forecast_days, is_predictive=True
n_observed_disease_ed_visits_datapoints= max(len(data_observed_disease_hospital_admissions), max_ww_sampled_days) + n_forecast_days, is_predictive=True
)
```

Expand Down Expand Up @@ -437,14 +437,14 @@ my_model_hosp_only_fit.run(
num_warmup=750,
num_samples=500,
rng_key=jax.random.key(223),
data_observed_hospital_admissions=data_observed_hospital_admissions,
data_observed_disease_hospital_admissions=data_observed_disease_hospital_admissions,
mcmc_args=dict(num_chains=4)
)
```

```{python}
posterior_predictive_hosp_only = my_model_hosp_only_fit.posterior_predictive(
n_datapoints= len(data_observed_hospital_admissions) + n_forecast_days,is_predictive=True
n_observed_disease_ed_visits_datapoints= len(data_observed_disease_hospital_admissions) + n_forecast_days,is_predictive=True
)
```

Expand Down
4 changes: 4 additions & 0 deletions pipelines/build_pyrenew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def build_model_from_dir(model_dir):
data_observed_disease_ed_visits = jnp.array(
model_data["data_observed_disease_ed_visits"]
)
data_observed_disease_hospital_admissions = jnp.array(
model_data["data_observed_disease_hospital_admissions"]
)
state_pop = jnp.array(model_data["state_pop"])

right_truncation_pmf_rv = DeterministicVariable(
Expand Down Expand Up @@ -75,5 +78,6 @@ def build_model_from_dir(model_dir):
return (
my_model,
data_observed_disease_ed_visits,
data_observed_disease_hospital_admissions,
right_truncation_offset,
)
1 change: 1 addition & 0 deletions pipelines/fit_pyrenew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def fit_and_save_model(
(
my_model,
data_observed_disease_ed_visits,
data_observed_disease_hospital_admissions,
right_truncation_offset,
) = build_model_from_dir(model_run_dir)
my_model.run(
Expand Down
10 changes: 9 additions & 1 deletion pipelines/generate_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def generate_and_save_predictions(
(
my_model,
data_observed_disease_ed_visits,
data_observed_disease_hospital_admissions,
right_truncation_offset,
) = build_model_from_dir(model_run_dir)

Expand All @@ -33,7 +34,14 @@ def generate_and_save_predictions(
my_model.mcmc.sampler = fresh_sampler

posterior_predictive = my_model.posterior_predictive(
n_datapoints=len(data_observed_disease_ed_visits) + n_forecast_points
n_observed_disease_ed_visits_datapoints=len(
data_observed_disease_ed_visits
)
+ n_forecast_points,
n_observed_hospital_admissions_datapoints=len(
data_observed_disease_hospital_admissions
)
+ n_forecast_points // 7,
)

idata = az.from_numpyro(
Expand Down
58 changes: 42 additions & 16 deletions pyrenew_hew/pyrenew_hew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
InitializeInfectionsExponentialGrowth,
)
from pyrenew.metaclass import Model
from pyrenew.observation import NegativeBinomialObservation
from pyrenew.observation import NegativeBinomialObservation, PoissonObservation
from pyrenew.process import ARProcess, DifferencedProcess
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable

Expand Down Expand Up @@ -84,28 +84,41 @@ def validate(self): # numpydoc ignore=GL08

def sample(
self,
n_datapoints=None,
n_observed_disease_ed_visits_datapoints=None,
n_observed_hospital_admissions_datapoints=None,
data_observed_disease_ed_visits=None,
data_observed_disease_hospital_admissions=None,
right_truncation_offset=None,
): # numpydoc ignore=GL08
if n_datapoints is None and data_observed_disease_ed_visits is None:
if (
n_observed_disease_ed_visits_datapoints is None
and data_observed_disease_ed_visits is None
):
raise ValueError(
"Either n_datapoints or data_observed_disease_ed_visits "
"Either n_observed_disease_ed_visits_datapoints or data_observed_disease_ed_visits "
"must be passed."
)
elif (
n_datapoints is not None
n_observed_disease_ed_visits_datapoints is not None
and data_observed_disease_ed_visits is not None
):
raise ValueError(
"Cannot pass both n_datapoints and data_observed_disease_ed_visits."
"Cannot pass both n_observed_disease_ed_visits_datapoints and data_observed_disease_ed_visits."
)
elif n_observed_disease_ed_visits_datapoints is None:
n_observed_disease_ed_visits_datapoints = len(
data_observed_disease_ed_visits
)

if (
n_observed_hospital_admissions_datapoints is None
and data_observed_disease_hospital_admissions is not None
):
n_observed_hospital_admissions_datapoints = len(
data_observed_disease_hospital_admissions
)
elif n_datapoints is None:
n_datapoints = len(data_observed_disease_ed_visits)
else:
n_datapoints = n_datapoints

n_weeks_post_init = n_datapoints // 7 + 1
n_weeks_post_init = n_observed_disease_ed_visits_datapoints // 7 + 1
i0 = self.infection_initialization_process()

eta_sd = self.eta_sd_rv()
Expand All @@ -127,7 +140,7 @@ def sample(

rtu = repeat_until_n(
data=jnp.exp(log_rtu_weekly),
n_timepoints=n_datapoints,
n_timepoints=n_observed_disease_ed_visits_datapoints,
offset=0,
period_size=7,
)
Expand Down Expand Up @@ -174,22 +187,24 @@ def sample(
iedr = jnp.repeat(
transformation.SigmoidTransform()(p_ed_ar + p_ed_mean),
repeats=7,
)[:n_datapoints]
)[:n_observed_disease_ed_visits_datapoints]
# this is only applied after the ed visits are generated, not to all the latent infections. This is why we cannot apply the iedr in compute_delay_ascertained_incidence
# see https://github.com/CDCgov/ww-inference-model/issues/43

numpyro.deterministic("iedr", iedr)

ed_wday_effect_raw = self.ed_wday_effect_rv()
ed_wday_effect = tile_until_n(ed_wday_effect_raw, n_datapoints)
ed_wday_effect = tile_until_n(
ed_wday_effect_raw, n_observed_disease_ed_visits_datapoints
)

inf_to_ed = self.inf_to_ed_rv()

potential_latent_ed_visits = compute_delay_ascertained_incidence(
p_observed_given_incident=1,
latent_incidence=latent_infections,
delay_incidence_to_observation_pmf=inf_to_ed,
)[-n_datapoints:]
)[-n_observed_disease_ed_visits_datapoints:]

latent_ed_visits_final = (
potential_latent_ed_visits * iedr * ed_wday_effect * self.state_pop
Expand All @@ -200,7 +215,8 @@ def sample(
self.right_truncation_cdf_rv()[right_truncation_offset:]
)
n_points_to_prepend = (
n_datapoints - prop_already_reported_tail.shape[0]
n_observed_disease_ed_visits_datapoints
- prop_already_reported_tail.shape[0]
)
prop_already_reported = jnp.pad(
prop_already_reported_tail,
Expand All @@ -223,6 +239,16 @@ def sample(
obs=data_observed_disease_ed_visits,
)

if n_observed_hospital_admissions_datapoints is not None:
hospital_admissions_obs_rv = PoissonObservation(
"observed_hospital_admissions"
)
data_observed_disease_hospital_admissions = (
hospital_admissions_obs_rv(
mu=jnp.ones(n_observed_hospital_admissions_datapoints) + 50
)
)

return observed_ed_visits


Expand Down

0 comments on commit d5c0199

Please sign in to comment.