Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update pyrenew_hew_model sample signature for PyRenew-HE #265

Merged
merged 24 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions demos/hosp_only_ww_model/pyrenew_hew_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ We check that we can simulate from the prior predictive
n_forecast_days = 35

prior_predictive = my_pyrenew_hew_model.prior_predictive(
n_datapoints=len(data_observed_disease_ed_visits) + n_forecast_days,
n_observed_disease_ed_visits_datapoints=len(data_observed_disease_ed_visits) + n_forecast_days,
numpyro_predictive_args={"num_samples": 200},
)
```
Expand All @@ -75,7 +75,7 @@ Create the posterior predictive and forecast:
```{python}
# | label: posterior predictive
posterior_predictive = my_pyrenew_hew_model.posterior_predictive(
n_datapoints=len(data_observed_disease_ed_visits) + n_forecast_days
n_observed_disease_ed_visits_datapoints=len(data_observed_disease_ed_visits) + n_forecast_days
)
```

Expand Down
12 changes: 6 additions & 6 deletions demos/ww_model/ww_model_demo.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ max_ww_sampled_days = max(stan_data["ww_sampled_times"])
lab_site_to_subpop_map = jnp.array(stan_data["lab_site_to_subpop_map"]) -1 #vector mapping the subpops to lab-site combos

data_observed_log_conc = jnp.array(stan_data["log_conc"])
data_observed_hospital_admissions = jnp.array(stan_data["hosp"])
data_observed_disease_hospital_admissions = jnp.array(stan_data["hosp"])
```

```{python}
Expand Down Expand Up @@ -327,7 +327,7 @@ Check that we can simulate from the prior predictive
n_forecast_days = 35

prior_predictive = my_model.prior_predictive(
n_datapoints= max(len(data_observed_hospital_admissions), max_ww_sampled_days) + n_forecast_days,
n_observed_disease_ed_visits_datapoints= max(len(data_observed_disease_hospital_admissions), max_ww_sampled_days) + n_forecast_days,
numpyro_predictive_args={"num_samples": 100},
)
```
Expand All @@ -339,7 +339,7 @@ my_model.run(
num_warmup=750,
num_samples=500,
rng_key=jax.random.key(223),
data_observed_hospital_admissions=data_observed_hospital_admissions,
data_observed_disease_hospital_admissions=data_observed_disease_hospital_admissions,
data_observed_log_conc=data_observed_log_conc,
mcmc_args=dict(num_chains=4)
)
Expand All @@ -349,7 +349,7 @@ Simulate the posterior predictive distribution
```{python}
# | label: posterior predictive
posterior_predictive = my_model.posterior_predictive(
n_datapoints= max(len(data_observed_hospital_admissions), max_ww_sampled_days) + n_forecast_days, is_predictive=True
n_observed_disease_ed_visits_datapoints= max(len(data_observed_disease_hospital_admissions), max_ww_sampled_days) + n_forecast_days, is_predictive=True
)
```

Expand Down Expand Up @@ -437,14 +437,14 @@ my_model_hosp_only_fit.run(
num_warmup=750,
num_samples=500,
rng_key=jax.random.key(223),
data_observed_hospital_admissions=data_observed_hospital_admissions,
data_observed_disease_hospital_admissions=data_observed_disease_hospital_admissions,
mcmc_args=dict(num_chains=4)
)
```

```{python}
posterior_predictive_hosp_only = my_model_hosp_only_fit.posterior_predictive(
n_datapoints= len(data_observed_hospital_admissions) + n_forecast_days,is_predictive=True
n_observed_disease_ed_visits_datapoints= len(data_observed_disease_hospital_admissions) + n_forecast_days,is_predictive=True
)
```

Expand Down
4 changes: 4 additions & 0 deletions pipelines/build_pyrenew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
data_observed_disease_ed_visits = jnp.array(
model_data["data_observed_disease_ed_visits"]
)
data_observed_disease_hospital_admissions = jnp.array(

Check warning on line 37 in pipelines/build_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/build_pyrenew_model.py#L37

Added line #L37 was not covered by tests
model_data["data_observed_disease_hospital_admissions"]
)
state_pop = jnp.array(model_data["state_pop"])

right_truncation_pmf_rv = DeterministicVariable(
Expand Down Expand Up @@ -75,5 +78,6 @@
return (
my_model,
data_observed_disease_ed_visits,
data_observed_disease_hospital_admissions,
right_truncation_offset,
)
1 change: 1 addition & 0 deletions pipelines/fit_pyrenew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def fit_and_save_model(
(
my_model,
data_observed_disease_ed_visits,
data_observed_disease_hospital_admissions,
right_truncation_offset,
) = build_model_from_dir(model_run_dir)
my_model.run(
Expand Down
10 changes: 9 additions & 1 deletion pipelines/generate_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def generate_and_save_predictions(
(
my_model,
data_observed_disease_ed_visits,
data_observed_disease_hospital_admissions,
right_truncation_offset,
) = build_model_from_dir(model_run_dir)

Expand All @@ -33,7 +34,14 @@ def generate_and_save_predictions(
my_model.mcmc.sampler = fresh_sampler

posterior_predictive = my_model.posterior_predictive(
n_datapoints=len(data_observed_disease_ed_visits) + n_forecast_points
n_observed_disease_ed_visits_datapoints=len(
data_observed_disease_ed_visits
)
+ n_forecast_points,
n_observed_hospital_admissions_datapoints=len(
data_observed_disease_hospital_admissions
)
+ n_forecast_points // 7,
)

idata = az.from_numpyro(
Expand Down
58 changes: 42 additions & 16 deletions pyrenew_hew/pyrenew_hew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
InitializeInfectionsExponentialGrowth,
)
from pyrenew.metaclass import Model
from pyrenew.observation import NegativeBinomialObservation
from pyrenew.observation import NegativeBinomialObservation, PoissonObservation

Check warning on line 17 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L17

Added line #L17 was not covered by tests
from pyrenew.process import ARProcess, DifferencedProcess
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable

Expand Down Expand Up @@ -84,28 +84,41 @@

def sample(
self,
n_datapoints=None,
n_observed_disease_ed_visits_datapoints=None,
n_observed_hospital_admissions_datapoints=None,
data_observed_disease_ed_visits=None,
data_observed_disease_hospital_admissions=None,
right_truncation_offset=None,
): # numpydoc ignore=GL08
if n_datapoints is None and data_observed_disease_ed_visits is None:
if (

Check warning on line 93 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L93

Added line #L93 was not covered by tests
n_observed_disease_ed_visits_datapoints is None
and data_observed_disease_ed_visits is None
):
raise ValueError(
"Either n_datapoints or data_observed_disease_ed_visits "
"Either n_observed_disease_ed_visits_datapoints or data_observed_disease_ed_visits "
"must be passed."
)
elif (
n_datapoints is not None
n_observed_disease_ed_visits_datapoints is not None
and data_observed_disease_ed_visits is not None
):
raise ValueError(
"Cannot pass both n_datapoints and data_observed_disease_ed_visits."
"Cannot pass both n_observed_disease_ed_visits_datapoints and data_observed_disease_ed_visits."
)
elif n_observed_disease_ed_visits_datapoints is None:
n_observed_disease_ed_visits_datapoints = len(

Check warning on line 109 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L108-L109

Added lines #L108 - L109 were not covered by tests
data_observed_disease_ed_visits
)

if (

Check warning on line 113 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L113

Added line #L113 was not covered by tests
n_observed_hospital_admissions_datapoints is None
and data_observed_disease_hospital_admissions is not None
):
n_observed_hospital_admissions_datapoints = len(

Check warning on line 117 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L117

Added line #L117 was not covered by tests
data_observed_disease_hospital_admissions
)
elif n_datapoints is None:
n_datapoints = len(data_observed_disease_ed_visits)
else:
n_datapoints = n_datapoints

n_weeks_post_init = n_datapoints // 7 + 1
n_weeks_post_init = n_observed_disease_ed_visits_datapoints // 7 + 1

Check warning on line 121 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L121

Added line #L121 was not covered by tests
i0 = self.infection_initialization_process()

eta_sd = self.eta_sd_rv()
Expand All @@ -127,7 +140,7 @@

rtu = repeat_until_n(
data=jnp.exp(log_rtu_weekly),
n_timepoints=n_datapoints,
n_timepoints=n_observed_disease_ed_visits_datapoints,
offset=0,
period_size=7,
)
Expand Down Expand Up @@ -174,22 +187,24 @@
iedr = jnp.repeat(
transformation.SigmoidTransform()(p_ed_ar + p_ed_mean),
repeats=7,
)[:n_datapoints]
)[:n_observed_disease_ed_visits_datapoints]
# this is only applied after the ed visits are generated, not to all the latent infections. This is why we cannot apply the iedr in compute_delay_ascertained_incidence
# see https://github.com/CDCgov/ww-inference-model/issues/43

numpyro.deterministic("iedr", iedr)

ed_wday_effect_raw = self.ed_wday_effect_rv()
ed_wday_effect = tile_until_n(ed_wday_effect_raw, n_datapoints)
ed_wday_effect = tile_until_n(

Check warning on line 197 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L197

Added line #L197 was not covered by tests
ed_wday_effect_raw, n_observed_disease_ed_visits_datapoints
)

inf_to_ed = self.inf_to_ed_rv()

potential_latent_ed_visits = compute_delay_ascertained_incidence(
p_observed_given_incident=1,
latent_incidence=latent_infections,
delay_incidence_to_observation_pmf=inf_to_ed,
)[-n_datapoints:]
)[-n_observed_disease_ed_visits_datapoints:]

latent_ed_visits_final = (
potential_latent_ed_visits * iedr * ed_wday_effect * self.state_pop
Expand All @@ -200,7 +215,8 @@
self.right_truncation_cdf_rv()[right_truncation_offset:]
)
n_points_to_prepend = (
n_datapoints - prop_already_reported_tail.shape[0]
n_observed_disease_ed_visits_datapoints
- prop_already_reported_tail.shape[0]
)
prop_already_reported = jnp.pad(
prop_already_reported_tail,
Expand All @@ -223,6 +239,16 @@
obs=data_observed_disease_ed_visits,
)

if n_observed_hospital_admissions_datapoints is not None:
hospital_admissions_obs_rv = PoissonObservation(

Check warning on line 243 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L242-L243

Added lines #L242 - L243 were not covered by tests
"observed_hospital_admissions"
)
data_observed_disease_hospital_admissions = (

Check warning on line 246 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L246

Added line #L246 was not covered by tests
hospital_admissions_obs_rv(
mu=jnp.ones(n_observed_hospital_admissions_datapoints) + 50
)
)

return observed_ed_visits


Expand Down
Loading