From 5d790b74f63fd9f3165c2029eff31402fe4d02b9 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Tue, 7 Jan 2025 14:30:54 -0500 Subject: [PATCH] Modularize model in preparation for fitting real NHSN data (#277) --- pipelines/build_pyrenew_model.py | 46 +++- pipelines/priors/prod_priors.py | 9 +- pyrenew_hew/pyrenew_hew_model.py | 382 +++++++++++++++++++++++-------- 3 files changed, 332 insertions(+), 105 deletions(-) diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index 115c031c..f319d2e9 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -4,7 +4,13 @@ import jax.numpy as jnp from pyrenew.deterministic import DeterministicVariable -from pyrenew_hew.pyrenew_hew_model import pyrenew_hew_model +from pyrenew_hew.pyrenew_hew_model import ( + EDVisitObservationProcess, + HospAdmitObservationProcess, + LatentInfectionProcess, + PyrenewHEWModel, + WastewaterObservationProcess, +) def build_model_from_dir(model_dir): @@ -21,6 +27,11 @@ def build_model_from_dir(model_dir): "inf_to_ed", jnp.array(model_data["inf_to_ed_pmf"]) ) # check if off by 1 or reversed + # use same as inf to ed, per NNH guidelines + inf_to_hosp_admit_rv = DeterministicVariable( + "inf_to_hosp_admit", jnp.array(model_data["inf_to_ed_pmf"]) + ) # check if off by 1 or reversed + generation_interval_pmf_rv = DeterministicVariable( "generation_interval_pmf", jnp.array(model_data["generation_interval_pmf"]), @@ -37,9 +48,9 @@ def build_model_from_dir(model_dir): data_observed_disease_hospital_admissions = jnp.array( model_data["data_observed_disease_hospital_admissions"] ) - state_pop = jnp.array(model_data["state_pop"]) + population_size = jnp.array(model_data["state_pop"]) - right_truncation_pmf_rv = DeterministicVariable( + ed_right_truncation_pmf_rv = DeterministicVariable( "right_truncation_pmf", jnp.array(model_data["right_truncation_pmf"]) ) @@ -55,8 +66,7 @@ def build_model_from_dir(model_dir): right_truncation_offset = 
model_data["right_truncation_offset"] - my_model = pyrenew_hew_model( - state_pop=state_pop, + my_latent_infection_model = LatentInfectionProcess( i0_first_obs_n_rv=priors["i0_first_obs_n_rv"], initialization_rate_rv=priors["initialization_rate_rv"], log_r_mu_intercept_rv=priors["log_r_mu_intercept_rv"], @@ -65,14 +75,34 @@ def build_model_from_dir(model_dir): generation_interval_pmf_rv=generation_interval_pmf_rv, infection_feedback_strength_rv=priors["inf_feedback_strength_rv"], infection_feedback_pmf_rv=infection_feedback_pmf_rv, + n_initialization_points=uot, + ) + + my_ed_visit_obs_model = EDVisitObservationProcess( p_ed_mean_rv=priors["p_ed_visit_mean_rv"], p_ed_w_sd_rv=priors["p_ed_visit_w_sd_rv"], autoreg_p_ed_rv=priors["autoreg_p_ed_visit_rv"], ed_wday_effect_rv=priors["ed_visit_wday_effect_rv"], inf_to_ed_rv=inf_to_ed_rv, - phi_rv=priors["phi_rv"], - right_truncation_pmf_rv=right_truncation_pmf_rv, - n_initialization_points=uot, + ed_neg_bin_concentration_rv=(priors["ed_neg_bin_concentration_rv"]), + ed_right_truncation_pmf_rv=ed_right_truncation_pmf_rv, + ) + + my_hosp_admit_obs_model = HospAdmitObservationProcess( + inf_to_hosp_admit_rv=inf_to_hosp_admit_rv, + hosp_admit_neg_bin_concentration_rv=( + priors["hosp_admit_neg_bin_concentration_rv"] + ), + ) + + my_wastewater_obs_model = WastewaterObservationProcess() + + my_model = PyrenewHEWModel( + population_size=population_size, + latent_infection_process_rv=my_latent_infection_model, + ed_visit_obs_process_rv=my_ed_visit_obs_model, + hosp_admit_obs_process_rv=my_hosp_admit_obs_model, + wastewater_obs_process_rv=my_wastewater_obs_model, ) return ( diff --git a/pipelines/priors/prod_priors.py b/pipelines/priors/prod_priors.py index 05c4a9f4..8625c20a 100644 --- a/pipelines/priors/prod_priors.py +++ b/pipelines/priors/prod_priors.py @@ -65,4 +65,11 @@ ) # Based on looking at some historical posteriors. 
-phi_rv = DistributionalVariable("phi", dist.LogNormal(4, 1)) +ed_neg_bin_concentration_rv = DistributionalVariable( + "ed_visit_neg_bin_concentration", dist.LogNormal(4, 1) +) + +# more diffuse than ED visit, same median (same log-scale location) +hosp_admit_neg_bin_concentration_rv = DistributionalVariable( + "hosp_admit_neg_bin_concentration", dist.LogNormal(4, 2) +) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 88658739..b77f607e 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -5,6 +5,7 @@ import numpyro import numpyro.distributions as dist import pyrenew.transformation as transformation +from jax.typing import ArrayLike from pyrenew.arrayutils import repeat_until_n, tile_until_n from pyrenew.convolve import compute_delay_ascertained_incidence from pyrenew.deterministic import DeterministicVariable @@ -13,7 +14,7 @@ InfectionsWithFeedback, InitializeInfectionsExponentialGrowth, ) -from pyrenew.metaclass import Model +from pyrenew.metaclass import Model, RandomVariable from pyrenew.observation import NegativeBinomialObservation, PoissonObservation from pyrenew.process import ARProcess, DifferencedProcess from pyrenew.randomvariable import DistributionalVariable, TransformedVariable @@ -21,27 +22,19 @@ from pyrenew_hew.utils import convert_to_logmean_log_sd -class pyrenew_hew_model(Model): # numpydoc ignore=GL08 +class LatentInfectionProcess(RandomVariable): def __init__( self, - state_pop, - i0_first_obs_n_rv, - initialization_rate_rv, - log_r_mu_intercept_rv, - autoreg_rt_rv, # ar process - eta_sd_rv, # sd of random walk for ar process - generation_interval_pmf_rv, - infection_feedback_strength_rv, - infection_feedback_pmf_rv, - p_ed_mean_rv, - p_ed_w_sd_rv, - autoreg_p_ed_rv, - ed_wday_effect_rv, - inf_to_ed_rv, - phi_rv, - right_truncation_pmf_rv, # when unnamed deterministic variables are allowed, we could default this to 1. 
- n_initialization_points, - ): # numpydoc ignore=GL08 + i0_first_obs_n_rv: RandomVariable, + initialization_rate_rv: RandomVariable, + log_r_mu_intercept_rv: RandomVariable, + autoreg_rt_rv: RandomVariable, # ar coefficient of AR(1) process on R'(t) + eta_sd_rv: RandomVariable, # sd of random walk for ar process on R'(t) + generation_interval_pmf_rv: RandomVariable, + infection_feedback_strength_rv: RandomVariable, + infection_feedback_pmf_rv: RandomVariable, + n_initialization_points: int, + ) -> None: self.infection_initialization_process = InfectionInitializationProcess( "I0_initialization", i0_first_obs_n_rv, @@ -55,70 +48,24 @@ def __init__( infection_feedback_pmf=infection_feedback_pmf_rv, ) - self.p_ed_ar_proc = ARProcess() self.ar_diff = DifferencedProcess( fundamental_process=ARProcess(), differencing_order=1, ) - self.right_truncation_cdf_rv = TransformedVariable( - "right_truncation_cdf", right_truncation_pmf_rv, jnp.cumsum - ) self.autoreg_rt_rv = autoreg_rt_rv self.eta_sd_rv = eta_sd_rv self.log_r_mu_intercept_rv = log_r_mu_intercept_rv self.generation_interval_pmf_rv = generation_interval_pmf_rv self.infection_feedback_pmf_rv = infection_feedback_pmf_rv - self.p_ed_mean_rv = p_ed_mean_rv - self.p_ed_w_sd_rv = p_ed_w_sd_rv - self.autoreg_p_ed_rv = autoreg_p_ed_rv - self.ed_wday_effect_rv = ed_wday_effect_rv - self.inf_to_ed_rv = inf_to_ed_rv - self.phi_rv = phi_rv - self.state_pop = state_pop - self.n_initialization_points = n_initialization_points - return None - def validate(self): # numpydoc ignore=GL08 - return None + def validate(self): + pass - def sample( - self, - n_observed_disease_ed_visits_datapoints=None, - n_observed_hospital_admissions_datapoints=None, - data_observed_disease_ed_visits=None, - data_observed_disease_hospital_admissions=None, - right_truncation_offset=None, - ): # numpydoc ignore=GL08 - if ( - n_observed_disease_ed_visits_datapoints is None - and data_observed_disease_ed_visits is None - ): - raise ValueError( - "Either 
n_observed_disease_ed_visits_datapoints or data_observed_disease_ed_visits " - "must be passed." - ) - elif ( - n_observed_disease_ed_visits_datapoints is not None - and data_observed_disease_ed_visits is not None - ): - raise ValueError( - "Cannot pass both n_observed_disease_ed_visits_datapoints and data_observed_disease_ed_visits." - ) - elif n_observed_disease_ed_visits_datapoints is None: - n_observed_disease_ed_visits_datapoints = len( - data_observed_disease_ed_visits - ) - - if ( - n_observed_hospital_admissions_datapoints is None - and data_observed_disease_hospital_admissions is not None - ): - n_observed_hospital_admissions_datapoints = len( - data_observed_disease_hospital_admissions - ) - - n_weeks_post_init = n_observed_disease_ed_visits_datapoints // 7 + 1 + def sample(self, n_days_post_init: int, n_weeks_post_init: int): + """ + Sample latent infections. + """ i0 = self.infection_initialization_process() eta_sd = self.eta_sd_rv() @@ -140,7 +87,7 @@ def sample( rtu = repeat_until_n( data=jnp.exp(log_rtu_weekly), - n_timepoints=n_observed_disease_ed_visits_datapoints, + n_timepoints=n_days_post_init, offset=0, period_size=7, ) @@ -163,6 +110,46 @@ def sample( numpyro.deterministic("rt", inf_with_feedback_proc_sample.rt) numpyro.deterministic("latent_infections", latent_infections) + return latent_infections + + +class EDVisitObservationProcess(RandomVariable): + def __init__( + self, + p_ed_mean_rv: RandomVariable, + p_ed_w_sd_rv: RandomVariable, + autoreg_p_ed_rv: RandomVariable, + ed_wday_effect_rv: RandomVariable, + inf_to_ed_rv: RandomVariable, + ed_neg_bin_concentration_rv: RandomVariable, + ed_right_truncation_pmf_rv: RandomVariable, + ) -> None: + self.p_ed_ar_proc = ARProcess() + self.p_ed_mean_rv = p_ed_mean_rv + self.p_ed_w_sd_rv = p_ed_w_sd_rv + self.autoreg_p_ed_rv = autoreg_p_ed_rv + self.ed_wday_effect_rv = ed_wday_effect_rv + self.inf_to_ed_rv = inf_to_ed_rv + self.ed_right_truncation_cdf_rv = TransformedVariable( + 
"ed_right_truncation_cdf", ed_right_truncation_pmf_rv, jnp.cumsum + ) + self.ed_neg_bin_concentration_rv = ed_neg_bin_concentration_rv + + def validate(self): + pass + + def sample( + self, + latent_infections: ArrayLike, + population_size: int, + data_observed_disease_ed_visits: ArrayLike, + n_observed_disease_ed_visits_datapoints: int, + n_weeks_post_init: int, + right_truncation_offset: ArrayLike = None, + ) -> ArrayLike: + """ + Observe and/or predict ED visit values + """ p_ed_mean = self.p_ed_mean_rv() p_ed_w_sd = self.p_ed_w_sd_rv() autoreg_p_ed = self.autoreg_p_ed_rv() @@ -188,7 +175,9 @@ def sample( transformation.SigmoidTransform()(p_ed_ar + p_ed_mean), repeats=7, )[:n_observed_disease_ed_visits_datapoints] - # this is only applied after the ed visits are generated, not to all the latent infections. This is why we cannot apply the iedr in compute_delay_ascertained_incidence + # this is only applied after the ed visits are generated, not to all + # the latent infections. This is why we cannot apply the iedr in + # compute_delay_ascertained_incidence # see https://github.com/CDCgov/ww-inference-model/issues/43 numpyro.deterministic("iedr", iedr) @@ -207,12 +196,15 @@ def sample( )[-n_observed_disease_ed_visits_datapoints:] latent_ed_visits_final = ( - potential_latent_ed_visits * iedr * ed_wday_effect * self.state_pop + potential_latent_ed_visits + * iedr + * ed_wday_effect + * population_size ) if right_truncation_offset is not None: prop_already_reported_tail = jnp.flip( - self.right_truncation_cdf_rv()[right_truncation_offset:] + self.ed_right_truncation_cdf_rv()[right_truncation_offset:] ) n_points_to_prepend = ( n_observed_disease_ed_visits_datapoints @@ -231,7 +223,8 @@ def sample( latent_ed_visits_now = latent_ed_visits_final ed_visit_obs_rv = NegativeBinomialObservation( - "observed_ed_visits", concentration_rv=self.phi_rv + "observed_ed_visits", + concentration_rv=self.ed_neg_bin_concentration_rv, ) observed_ed_visits = ed_visit_obs_rv( @@ -239,17 
+232,177 @@ def sample( obs=data_observed_disease_ed_visits, ) - if n_observed_hospital_admissions_datapoints is not None: - hospital_admissions_obs_rv = PoissonObservation( - "observed_hospital_admissions" + return observed_ed_visits + + +class HospAdmitObservationProcess(RandomVariable): + def __init__( + self, + inf_to_hosp_admit_rv: RandomVariable, + hosp_admit_neg_bin_concentration_rv: RandomVariable, + ): + self.inf_to_hosp_admit_rv = inf_to_hosp_admit_rv + self.hosp_admit_neg_bin_concentration_rv = ( + hosp_admit_neg_bin_concentration_rv + ) + + def validate(self): + pass + + def sample( + self, + latent_infections: ArrayLike, + population_size: int, + n_observed_hospital_admissions_datapoints: int, + data_observed_disease_hospital_admissions: ArrayLike | None = None, + ) -> ArrayLike: + """ + Observe and/or predict incident hospital admissions. + """ + inf_to_hosp_admit = self.inf_to_hosp_admit_rv() + + latent_hospital_admissions = compute_delay_ascertained_incidence( + p_observed_given_incident=1, + latent_incidence=latent_infections, + delay_incidence_to_observation_pmf=inf_to_hosp_admit, + )[-n_observed_hospital_admissions_datapoints:] + + hospital_admissions_obs_rv = NegativeBinomialObservation( + "observed_hospital_admissions", + concentration_rv=self.hosp_admit_neg_bin_concentration_rv, + ) + predicted_admissions = latent_hospital_admissions + sampled_admissions = hospital_admissions_obs_rv( + mu=predicted_admissions, + obs=data_observed_disease_hospital_admissions, + ) + return sampled_admissions + + +class WastewaterObservationProcess(RandomVariable): + """ + Placeholder for wastewater obs process + """ + + def __init__(self) -> None: + pass + + def sample(self): + pass + + def validate(self): + pass + + +class PyrenewHEWModel(Model): # numpydoc ignore=GL08 + def __init__( + self, + population_size: int, + latent_infection_process_rv: LatentInfectionProcess, + ed_visit_obs_process_rv: EDVisitObservationProcess, + hosp_admit_obs_process_rv: 
HospAdmitObservationProcess, + wastewater_obs_process_rv: WastewaterObservationProcess, + ) -> None: # numpydoc ignore=GL08 + self.population_size = population_size + self.latent_infection_process_rv = latent_infection_process_rv + self.ed_visit_obs_process_rv = ed_visit_obs_process_rv + self.hosp_admit_obs_process_rv = hosp_admit_obs_process_rv + return None + + def validate(self): # numpydoc ignore=GL08 + return None + + def sample( + self, + n_observed_disease_ed_visits_datapoints=None, + n_observed_hospital_admissions_datapoints=None, + data_observed_disease_ed_visits=None, + data_observed_disease_hospital_admissions=None, + right_truncation_offset=None, + ) -> ArrayLike: # numpydoc ignore=GL08 + if ( + n_observed_disease_ed_visits_datapoints is None + and data_observed_disease_ed_visits is None + ): + raise ValueError( + "Either n_observed_disease_ed_visits_datapoints or " + "data_observed_disease_ed_visits " + "must be passed." ) - data_observed_disease_hospital_admissions = ( - hospital_admissions_obs_rv( - mu=jnp.ones(n_observed_hospital_admissions_datapoints) + 50 - ) + elif ( + n_observed_disease_ed_visits_datapoints is not None + and data_observed_disease_ed_visits is not None + ): + raise ValueError( + "Cannot pass both n_observed_disease_ed_visits_datapoints " + "and data_observed_disease_ed_visits." 
+ ) + elif n_observed_disease_ed_visits_datapoints is None: + n_observed_disease_ed_visits_datapoints = len( + data_observed_disease_ed_visits ) - return observed_ed_visits + if ( + n_observed_hospital_admissions_datapoints is None + and data_observed_disease_hospital_admissions is not None + ): + n_observed_hospital_admissions_datapoints = len( + data_observed_disease_hospital_admissions + ) + elif n_observed_hospital_admissions_datapoints is None: + n_observed_hospital_admissions_datapoints = 0 + + n_weeks_post_init = n_observed_disease_ed_visits_datapoints // 7 + 1 + n_days_post_init = n_observed_disease_ed_visits_datapoints + + latent_infections = self.latent_infection_process_rv( + n_days_post_init=n_days_post_init, + n_weeks_post_init=n_weeks_post_init, + ) + + sample_ed_visits = True + sample_admissions = n_observed_hospital_admissions_datapoints > 0 + sample_wastewater = False + + sampled_ed_visits, sampled_admissions, sampled_wastewater = ( + None, + None, + None, + ) + + if sample_ed_visits: + sampled_ed_visits = self.ed_visit_obs_process_rv( + latent_infections=latent_infections, + population_size=self.population_size, + data_observed_disease_ed_visits=( + data_observed_disease_ed_visits + ), + n_observed_disease_ed_visits_datapoints=( + n_observed_disease_ed_visits_datapoints + ), + n_weeks_post_init=n_weeks_post_init, + right_truncation_offset=right_truncation_offset, + ) + + if sample_admissions: + sampled_admissions = self.hosp_admit_obs_process_rv( + latent_infections=latent_infections, + population_size=self.population_size, + n_observed_hospital_admissions_datapoints=( + n_observed_hospital_admissions_datapoints + ), + data_observed_disease_hospital_admissions=( + data_observed_disease_hospital_admissions + ), + ) + if sample_wastewater: + sampled_wastewater = self.wastewater_obs_process_rv() + + return { + "ed_visits": sampled_ed_visits, + "admissions": sampled_admissions, + "wastewater": sampled_wastewater, + } def 
create_pyrenew_hew_model_from_stan_data(stan_data_file): @@ -364,10 +517,27 @@ def create_pyrenew_hew_model_from_stan_data(stan_data_file): inv_sqrt_phi_prior_mean = stan_data["inv_sqrt_phi_prior_mean"] inv_sqrt_phi_prior_sd = stan_data["inv_sqrt_phi_prior_sd"] - phi_rv = TransformedVariable( - "phi", + ed_neg_bin_concentration_rv = TransformedVariable( + "ed_visit_neg_bin_concentration", + DistributionalVariable( + "inv_sqrt_ed_visit_neg_bin_conc", + dist.TruncatedNormal( + loc=inv_sqrt_phi_prior_mean, + scale=inv_sqrt_phi_prior_sd, + low=1 / jnp.sqrt(5000), + ), + ), + transforms=transformation.PowerTransform(-2), + ) + + inf_to_hosp_admit_rv = DeterministicVariable( + "inf_to_hosp_admit", jnp.array(stan_data["inf_to_hosp"]) + ) + + hosp_admit_neg_bin_concentration_rv = TransformedVariable( + "hosp_admit_neg_bin_concentration", DistributionalVariable( - "inv_sqrt_phi", + "inv_sqrt_hosp_admit_neg_bin_conc", dist.TruncatedNormal( loc=inv_sqrt_phi_prior_mean, scale=inv_sqrt_phi_prior_sd, @@ -379,14 +549,14 @@ def create_pyrenew_hew_model_from_stan_data(stan_data_file): uot = stan_data["uot"] uot = len(jnp.array(stan_data["inf_to_hosp"])) - state_pop = stan_data["state_pop"] + population_size = stan_data["state_pop"] data_observed_disease_ed_visits = jnp.array(stan_data["hosp"]) - right_truncation_pmf_rv = DeterministicVariable( - "right_truncation_pmf", jnp.array(1) + ed_right_truncation_pmf_rv = DeterministicVariable( + "ed_visit_right_truncation_pmf", jnp.array(1) ) - my_model = pyrenew_hew_model( - state_pop=state_pop, + + my_latent_infection_model = LatentInfectionProcess( i0_first_obs_n_rv=i0_first_obs_n_rv, initialization_rate_rv=initialization_rate_rv, log_r_mu_intercept_rv=log_r_mu_intercept_rv, @@ -395,14 +565,34 @@ def create_pyrenew_hew_model_from_stan_data(stan_data_file): generation_interval_pmf_rv=generation_interval_pmf_rv, infection_feedback_pmf_rv=infection_feedback_pmf_rv, infection_feedback_strength_rv=inf_feedback_strength_rv, + 
n_initialization_points=uot, + ) + + my_ed_visit_obs_model = EDVisitObservationProcess( p_ed_mean_rv=p_ed_mean_rv, p_ed_w_sd_rv=p_ed_w_sd_rv, autoreg_p_ed_rv=autoreg_p_ed_rv, ed_wday_effect_rv=ed_wday_effect_rv, inf_to_ed_rv=inf_to_ed_rv, - phi_rv=phi_rv, - right_truncation_pmf_rv=right_truncation_pmf_rv, - n_initialization_points=uot, + ed_neg_bin_concentration_rv=ed_neg_bin_concentration_rv, + ed_right_truncation_pmf_rv=ed_right_truncation_pmf_rv, + ) + + my_hosp_admit_obs_model = HospAdmitObservationProcess( + inf_to_hosp_admit_rv=inf_to_hosp_admit_rv, + hosp_admit_neg_bin_concentration_rv=( + hosp_admit_neg_bin_concentration_rv + ), + ) + + my_wastewater_obs_model = WastewaterObservationProcess() + + my_model = PyrenewHEWModel( + population_size=population_size, + latent_infection_process_rv=my_latent_infection_model, + ed_visit_obs_process_rv=my_ed_visit_obs_model, + hosp_admit_obs_process_rv=my_hosp_admit_obs_model, + wastewater_obs_process_rv=my_wastewater_obs_model, ) return my_model, data_observed_disease_ed_visits