diff --git a/demos/hosp_only_ww_model/pyrenew_hew_model.qmd b/demos/hosp_only_ww_model/pyrenew_hew_model.qmd index 6f6e4478..3e993fe0 100644 --- a/demos/hosp_only_ww_model/pyrenew_hew_model.qmd +++ b/demos/hosp_only_ww_model/pyrenew_hew_model.qmd @@ -50,7 +50,7 @@ We check that we can simulate from the prior predictive n_forecast_days = 35 prior_predictive = my_pyrenew_hew_model.prior_predictive( - n_datapoints=len(data_observed_disease_ed_visits) + n_forecast_days, + n_observed_disease_ed_visits_datapoints=len(data_observed_disease_ed_visits) + n_forecast_days, numpyro_predictive_args={"num_samples": 200}, ) ``` @@ -75,7 +75,7 @@ Create the posterior predictive and forecast: ```{python} # | label: posterior predictive posterior_predictive = my_pyrenew_hew_model.posterior_predictive( - n_datapoints=len(data_observed_disease_ed_visits) + n_forecast_days + n_observed_disease_ed_visits_datapoints=len(data_observed_disease_ed_visits) + n_forecast_days ) ``` diff --git a/demos/ww_model/ww_model_demo.qmd b/demos/ww_model/ww_model_demo.qmd index 9d7d2932..6747af6c 100644 --- a/demos/ww_model/ww_model_demo.qmd +++ b/demos/ww_model/ww_model_demo.qmd @@ -62,7 +62,7 @@ max_ww_sampled_days = max(stan_data["ww_sampled_times"]) lab_site_to_subpop_map = jnp.array(stan_data["lab_site_to_subpop_map"]) -1 #vector mapping the subpops to lab-site combos data_observed_log_conc = jnp.array(stan_data["log_conc"]) -data_observed_hospital_admissions = jnp.array(stan_data["hosp"]) +data_observed_disease_hospital_admissions = jnp.array(stan_data["hosp"]) ``` ```{python} @@ -327,7 +327,7 @@ Check that we can simulate from the prior predictive n_forecast_days = 35 prior_predictive = my_model.prior_predictive( - n_datapoints= max(len(data_observed_hospital_admissions), max_ww_sampled_days) + n_forecast_days, + n_observed_disease_ed_visits_datapoints= max(len(data_observed_disease_hospital_admissions), max_ww_sampled_days) + n_forecast_days, numpyro_predictive_args={"num_samples": 100}, ) ``` @@ -339,7 +339,7 @@ my_model.run( num_warmup=750, num_samples=500, rng_key=jax.random.key(223), - data_observed_hospital_admissions=data_observed_hospital_admissions, + data_observed_disease_hospital_admissions=data_observed_disease_hospital_admissions, data_observed_log_conc=data_observed_log_conc, mcmc_args=dict(num_chains=4) ) @@ -349,7 +349,7 @@ Simulate the posterior predictive distribution ```{python} # | label: posterior predictive posterior_predictive = my_model.posterior_predictive( - n_datapoints= max(len(data_observed_hospital_admissions), max_ww_sampled_days) + n_forecast_days, is_predictive=True + n_observed_disease_ed_visits_datapoints= max(len(data_observed_disease_hospital_admissions), max_ww_sampled_days) + n_forecast_days, is_predictive=True ) ``` @@ -437,14 +437,14 @@ my_model_hosp_only_fit.run( num_warmup=750, num_samples=500, rng_key=jax.random.key(223), - data_observed_hospital_admissions=data_observed_hospital_admissions, + data_observed_disease_hospital_admissions=data_observed_disease_hospital_admissions, mcmc_args=dict(num_chains=4) ) ``` ```{python} posterior_predictive_hosp_only = my_model_hosp_only_fit.posterior_predictive( - n_datapoints= len(data_observed_hospital_admissions) + n_forecast_days,is_predictive=True + n_observed_disease_ed_visits_datapoints= len(data_observed_disease_hospital_admissions) + n_forecast_days,is_predictive=True ) ``` diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index 5d062f38..115c031c 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -34,6 +34,9 @@ def build_model_from_dir(model_dir): data_observed_disease_ed_visits = jnp.array( model_data["data_observed_disease_ed_visits"] ) + data_observed_disease_hospital_admissions = jnp.array( + model_data["data_observed_disease_hospital_admissions"] + ) state_pop = jnp.array(model_data["state_pop"]) right_truncation_pmf_rv = DeterministicVariable( @@ -75,5 +78,6 @@ def build_model_from_dir(model_dir): return ( my_model, data_observed_disease_ed_visits, + data_observed_disease_hospital_admissions, right_truncation_offset, ) diff --git a/pipelines/fit_pyrenew_model.py b/pipelines/fit_pyrenew_model.py index a63278db..43ce5442 100644 --- a/pipelines/fit_pyrenew_model.py +++ b/pipelines/fit_pyrenew_model.py @@ -29,6 +29,7 @@ def fit_and_save_model( ( my_model, data_observed_disease_ed_visits, + data_observed_disease_hospital_admissions, right_truncation_offset, ) = build_model_from_dir(model_run_dir) my_model.run( diff --git a/pipelines/generate_predictive.py b/pipelines/generate_predictive.py index fb3a5c0d..e483af8c 100644 --- a/pipelines/generate_predictive.py +++ b/pipelines/generate_predictive.py @@ -18,6 +18,7 @@ def generate_and_save_predictions( ( my_model, data_observed_disease_ed_visits, + data_observed_disease_hospital_admissions, right_truncation_offset, ) = build_model_from_dir(model_run_dir) @@ -33,7 +34,14 @@ def generate_and_save_predictions( my_model.mcmc.sampler = fresh_sampler posterior_predictive = my_model.posterior_predictive( - n_datapoints=len(data_observed_disease_ed_visits) + n_forecast_points + n_observed_disease_ed_visits_datapoints=len( + data_observed_disease_ed_visits + ) + + n_forecast_points, + n_observed_hospital_admissions_datapoints=len( + data_observed_disease_hospital_admissions + ) + + n_forecast_points // 7, ) idata = az.from_numpyro( diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 0aac754f..88658739 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -14,7 +14,7 @@ InitializeInfectionsExponentialGrowth, ) from pyrenew.metaclass import Model -from pyrenew.observation import NegativeBinomialObservation +from pyrenew.observation import NegativeBinomialObservation, PoissonObservation from pyrenew.process import ARProcess, DifferencedProcess from pyrenew.randomvariable import DistributionalVariable, TransformedVariable @@ -84,28 +84,41 @@ def validate(self): # numpydoc ignore=GL08 def sample( self, - n_datapoints=None, + n_observed_disease_ed_visits_datapoints=None, + n_observed_hospital_admissions_datapoints=None, data_observed_disease_ed_visits=None, + data_observed_disease_hospital_admissions=None, right_truncation_offset=None, ): # numpydoc ignore=GL08 - if n_datapoints is None and data_observed_disease_ed_visits is None: + if ( + n_observed_disease_ed_visits_datapoints is None + and data_observed_disease_ed_visits is None + ): raise ValueError( - "Either n_datapoints or data_observed_disease_ed_visits " + "Either n_observed_disease_ed_visits_datapoints or data_observed_disease_ed_visits " "must be passed." ) elif ( - n_datapoints is not None + n_observed_disease_ed_visits_datapoints is not None and data_observed_disease_ed_visits is not None ): raise ValueError( - "Cannot pass both n_datapoints and data_observed_disease_ed_visits." + "Cannot pass both n_observed_disease_ed_visits_datapoints and data_observed_disease_ed_visits." + ) + elif n_observed_disease_ed_visits_datapoints is None: + n_observed_disease_ed_visits_datapoints = len( + data_observed_disease_ed_visits + ) + + if ( + n_observed_hospital_admissions_datapoints is None + and data_observed_disease_hospital_admissions is not None + ): + n_observed_hospital_admissions_datapoints = len( + data_observed_disease_hospital_admissions ) - elif n_datapoints is None: - n_datapoints = len(data_observed_disease_ed_visits) - else: - n_datapoints = n_datapoints - n_weeks_post_init = n_datapoints // 7 + 1 + n_weeks_post_init = n_observed_disease_ed_visits_datapoints // 7 + 1 i0 = self.infection_initialization_process() eta_sd = self.eta_sd_rv() @@ -127,7 +140,7 @@ def sample( rtu = repeat_until_n( data=jnp.exp(log_rtu_weekly), - n_timepoints=n_datapoints, + n_timepoints=n_observed_disease_ed_visits_datapoints, offset=0, period_size=7, ) @@ -174,14 +187,16 @@ def sample( iedr = jnp.repeat( transformation.SigmoidTransform()(p_ed_ar + p_ed_mean), repeats=7, - )[:n_datapoints] + )[:n_observed_disease_ed_visits_datapoints] # this is only applied after the ed visits are generated, not to all the latent infections. This is why we cannot apply the iedr in compute_delay_ascertained_incidence # see https://github.com/CDCgov/ww-inference-model/issues/43 numpyro.deterministic("iedr", iedr) ed_wday_effect_raw = self.ed_wday_effect_rv() - ed_wday_effect = tile_until_n(ed_wday_effect_raw, n_datapoints) + ed_wday_effect = tile_until_n( + ed_wday_effect_raw, n_observed_disease_ed_visits_datapoints + ) inf_to_ed = self.inf_to_ed_rv() @@ -189,7 +204,7 @@ def sample( p_observed_given_incident=1, latent_incidence=latent_infections, delay_incidence_to_observation_pmf=inf_to_ed, - )[-n_datapoints:] + )[-n_observed_disease_ed_visits_datapoints:] latent_ed_visits_final = ( potential_latent_ed_visits * iedr * ed_wday_effect * self.state_pop @@ -200,7 +215,8 @@ def sample( self.right_truncation_cdf_rv()[right_truncation_offset:] ) n_points_to_prepend = ( - n_datapoints - prop_already_reported_tail.shape[0] + n_observed_disease_ed_visits_datapoints + - prop_already_reported_tail.shape[0] ) prop_already_reported = jnp.pad( prop_already_reported_tail, @@ -223,6 +239,16 @@ def sample( obs=data_observed_disease_ed_visits, ) + if n_observed_hospital_admissions_datapoints is not None: + hospital_admissions_obs_rv = PoissonObservation( + "observed_hospital_admissions" + ) + data_observed_disease_hospital_admissions = ( + hospital_admissions_obs_rv( + mu=jnp.ones(n_observed_hospital_admissions_datapoints) + 50 + ) + ) + return observed_ed_visits