Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update pyrenew_hew_model sample signature for PyRenew-HE #265

Merged
24 commits merged on Dec 30, 2024
Show file tree
Hide file tree
Changes from 23 commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
4 changes: 2 additions & 2 deletions demos/hosp_only_ww_model/pyrenew_hew_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ We check that we can simulate from the prior predictive
n_forecast_days = 35

prior_predictive = my_pyrenew_hew_model.prior_predictive(
n_datapoints=len(data_observed_disease_ed_visits) + n_forecast_days,
n_observed_disease_ed_visits_datapoints=len(data_observed_disease_ed_visits) + n_forecast_days,
numpyro_predictive_args={"num_samples": 200},
)
```
Expand All @@ -75,7 +75,7 @@ Create the posterior predictive and forecast:
```{python}
# | label: posterior predictive
posterior_predictive = my_pyrenew_hew_model.posterior_predictive(
n_datapoints=len(data_observed_disease_ed_visits) + n_forecast_days
n_observed_disease_ed_visits_datapoints=len(data_observed_disease_ed_visits) + n_forecast_days
)
```

Expand Down
12 changes: 6 additions & 6 deletions demos/ww_model/ww_model_demo.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ max_ww_sampled_days = max(stan_data["ww_sampled_times"])
lab_site_to_subpop_map = jnp.array(stan_data["lab_site_to_subpop_map"]) -1 #vector mapping the subpops to lab-site combos

data_observed_log_conc = jnp.array(stan_data["log_conc"])
data_observed_hospital_admissions = jnp.array(stan_data["hosp"])
data_observed_disease_hospital_admissions = jnp.array(stan_data["hosp"])
```

```{python}
Expand Down Expand Up @@ -327,7 +327,7 @@ Check that we can simulate from the prior predictive
n_forecast_days = 35

prior_predictive = my_model.prior_predictive(
n_datapoints= max(len(data_observed_hospital_admissions), max_ww_sampled_days) + n_forecast_days,
n_observed_disease_ed_visits_datapoints= max(len(data_observed_disease_hospital_admissions), max_ww_sampled_days) + n_forecast_days,
numpyro_predictive_args={"num_samples": 100},
)
```
Expand All @@ -339,7 +339,7 @@ my_model.run(
num_warmup=750,
num_samples=500,
rng_key=jax.random.key(223),
data_observed_hospital_admissions=data_observed_hospital_admissions,
data_observed_disease_hospital_admissions=data_observed_disease_hospital_admissions,
data_observed_log_conc=data_observed_log_conc,
mcmc_args=dict(num_chains=4)
)
Expand All @@ -349,7 +349,7 @@ Simulate the posterior predictive distribution
```{python}
# | label: posterior predictive
posterior_predictive = my_model.posterior_predictive(
n_datapoints= max(len(data_observed_hospital_admissions), max_ww_sampled_days) + n_forecast_days, is_predictive=True
n_observed_disease_ed_visits_datapoints= max(len(data_observed_disease_hospital_admissions), max_ww_sampled_days) + n_forecast_days, is_predictive=True
)
```

Expand Down Expand Up @@ -437,14 +437,14 @@ my_model_hosp_only_fit.run(
num_warmup=750,
num_samples=500,
rng_key=jax.random.key(223),
data_observed_hospital_admissions=data_observed_hospital_admissions,
data_observed_disease_hospital_admissions=data_observed_disease_hospital_admissions,
mcmc_args=dict(num_chains=4)
)
```

```{python}
posterior_predictive_hosp_only = my_model_hosp_only_fit.posterior_predictive(
n_datapoints= len(data_observed_hospital_admissions) + n_forecast_days,is_predictive=True
n_observed_disease_ed_visits_datapoints= len(data_observed_disease_hospital_admissions) + n_forecast_days,is_predictive=True
)
```

Expand Down
4 changes: 4 additions & 0 deletions pipelines/build_pyrenew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def build_model_from_dir(model_dir):
data_observed_disease_ed_visits = jnp.array(
model_data["data_observed_disease_ed_visits"]
)
data_observed_disease_hospital_admissions = jnp.array(
model_data["data_observed_disease_hospital_admissions"]
)
state_pop = jnp.array(model_data["state_pop"])

right_truncation_pmf_rv = DeterministicVariable(
Expand Down Expand Up @@ -75,5 +78,6 @@ def build_model_from_dir(model_dir):
return (
my_model,
data_observed_disease_ed_visits,
data_observed_disease_hospital_admissions,
right_truncation_offset,
)
1 change: 1 addition & 0 deletions pipelines/fit_pyrenew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def fit_and_save_model(
(
my_model,
data_observed_disease_ed_visits,
data_observed_disease_hospital_admissions,
right_truncation_offset,
) = build_model_from_dir(model_run_dir)
my_model.run(
Expand Down
10 changes: 9 additions & 1 deletion pipelines/generate_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def generate_and_save_predictions(
(
my_model,
data_observed_disease_ed_visits,
data_observed_disease_hospital_admissions,
right_truncation_offset,
) = build_model_from_dir(model_run_dir)

Expand All @@ -33,7 +34,14 @@ def generate_and_save_predictions(
my_model.mcmc.sampler = fresh_sampler

posterior_predictive = my_model.posterior_predictive(
n_datapoints=len(data_observed_disease_ed_visits) + n_forecast_points
n_observed_disease_ed_visits_datapoints=len(
data_observed_disease_ed_visits
)
+ n_forecast_points,
n_observed_hospital_admissions_datapoints=len(
data_observed_disease_hospital_admissions
)
+ n_forecast_points // 7,
)

idata = az.from_numpyro(
Expand Down
2 changes: 1 addition & 1 deletion pipelines/prep_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def process_and_save_state(
"generation_interval_pmf": generation_interval_pmf,
"right_truncation_pmf": right_truncation_pmf,
"data_observed_disease_ed_visits": train_disease_ed_visits,
"data_observed_total_hospital_admissions": train_total_ed_visits,
"data_observed_disease_hospital_admissions": train_total_ed_visits,
"data_observed_disease_hospital_admissions": train_disease_hospital_admissions,
"nssp_training_dates": nssp_training_dates,
"nhsn_training_dates": nhsn_training_dates,
Expand Down
58 changes: 42 additions & 16 deletions pyrenew_hew/pyrenew_hew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
InitializeInfectionsExponentialGrowth,
)
from pyrenew.metaclass import Model
from pyrenew.observation import NegativeBinomialObservation
from pyrenew.observation import NegativeBinomialObservation, PoissonObservation
from pyrenew.process import ARProcess, DifferencedProcess
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable

Expand Down Expand Up @@ -84,28 +84,41 @@ def validate(self): # numpydoc ignore=GL08

def sample(
self,
n_datapoints=None,
n_observed_disease_ed_visits_datapoints=None,
n_observed_hospital_admissions_datapoints=None,
data_observed_disease_ed_visits=None,
data_observed_disease_hospital_admissions=None,
right_truncation_offset=None,
): # numpydoc ignore=GL08
if n_datapoints is None and data_observed_disease_ed_visits is None:
if (
n_observed_disease_ed_visits_datapoints is None
and data_observed_disease_ed_visits is None
):
raise ValueError(
"Either n_datapoints or data_observed_disease_ed_visits "
"Either n_observed_disease_ed_visits_datapoints or data_observed_disease_ed_visits "
"must be passed."
)
elif (
n_datapoints is not None
n_observed_disease_ed_visits_datapoints is not None
and data_observed_disease_ed_visits is not None
):
raise ValueError(
"Cannot pass both n_datapoints and data_observed_disease_ed_visits."
"Cannot pass both n_observed_disease_ed_visits_datapoints and data_observed_disease_ed_visits."
)
elif n_observed_disease_ed_visits_datapoints is None:
n_observed_disease_ed_visits_datapoints = len(
data_observed_disease_ed_visits
)

if (
n_observed_hospital_admissions_datapoints is None
and data_observed_disease_hospital_admissions is not None
):
n_observed_hospital_admissions_datapoints = len(
data_observed_disease_hospital_admissions
)
elif n_datapoints is None:
n_datapoints = len(data_observed_disease_ed_visits)
else:
n_datapoints = n_datapoints

n_weeks_post_init = n_datapoints // 7 + 1
n_weeks_post_init = n_observed_disease_ed_visits_datapoints // 7 + 1
i0 = self.infection_initialization_process()

eta_sd = self.eta_sd_rv()
Expand All @@ -127,7 +140,7 @@ def sample(

rtu = repeat_until_n(
data=jnp.exp(log_rtu_weekly),
n_timepoints=n_datapoints,
n_timepoints=n_observed_disease_ed_visits_datapoints,
offset=0,
period_size=7,
)
Expand Down Expand Up @@ -174,22 +187,24 @@ def sample(
iedr = jnp.repeat(
transformation.SigmoidTransform()(p_ed_ar + p_ed_mean),
repeats=7,
)[:n_datapoints]
)[:n_observed_disease_ed_visits_datapoints]
# this is only applied after the ed visits are generated, not to all the latent infections. This is why we cannot apply the iedr in compute_delay_ascertained_incidence
# see https://github.com/CDCgov/ww-inference-model/issues/43

numpyro.deterministic("iedr", iedr)

ed_wday_effect_raw = self.ed_wday_effect_rv()
ed_wday_effect = tile_until_n(ed_wday_effect_raw, n_datapoints)
ed_wday_effect = tile_until_n(
ed_wday_effect_raw, n_observed_disease_ed_visits_datapoints
)

inf_to_ed = self.inf_to_ed_rv()

potential_latent_ed_visits = compute_delay_ascertained_incidence(
p_observed_given_incident=1,
latent_incidence=latent_infections,
delay_incidence_to_observation_pmf=inf_to_ed,
)[-n_datapoints:]
)[-n_observed_disease_ed_visits_datapoints:]

latent_ed_visits_final = (
potential_latent_ed_visits * iedr * ed_wday_effect * self.state_pop
Expand All @@ -200,7 +215,8 @@ def sample(
self.right_truncation_cdf_rv()[right_truncation_offset:]
)
n_points_to_prepend = (
n_datapoints - prop_already_reported_tail.shape[0]
n_observed_disease_ed_visits_datapoints
- prop_already_reported_tail.shape[0]
)
prop_already_reported = jnp.pad(
prop_already_reported_tail,
Expand All @@ -223,6 +239,16 @@ def sample(
obs=data_observed_disease_ed_visits,
)

if n_observed_hospital_admissions_datapoints is not None:
hospital_admissions_obs_rv = PoissonObservation(
"observed_hospital_admissions"
)
data_observed_disease_hospital_admissions = (
hospital_admissions_obs_rv(
mu=jnp.ones(n_observed_hospital_admissions_datapoints) + 50
)
)

return observed_ed_visits


Expand Down