diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index b77f607e..82d87388 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -6,6 +6,7 @@ import numpyro.distributions as dist import pyrenew.transformation as transformation from jax.typing import ArrayLike +from numpyro.infer.reparam import LocScaleReparam from pyrenew.arrayutils import repeat_until_n, tile_until_n from pyrenew.convolve import compute_delay_ascertained_incidence from pyrenew.deterministic import DeterministicVariable @@ -34,15 +35,15 @@ def __init__( infection_feedback_strength_rv: RandomVariable, infection_feedback_pmf_rv: RandomVariable, n_initialization_points: int, + pop_fraction: ArrayLike = jnp.array([1]), + autoreg_rt_subpop_rv: RandomVariable = None, + sigma_rt_rv: RandomVariable = None, + sigma_i_first_obs_rv: RandomVariable = None, + sigma_initial_exp_growth_rate_rv: RandomVariable = None, + offset_ref_logit_i_first_obs_rv: RandomVariable = None, + offset_ref_initial_exp_growth_rate_rv: RandomVariable = None, + offset_ref_log_rt_rv: RandomVariable = None, ) -> None: - self.infection_initialization_process = InfectionInitializationProcess( - "I0_initialization", - i0_first_obs_n_rv, - InitializeInfectionsExponentialGrowth( - n_initialization_points, initialization_rate_rv, t_pre_init=0 - ), - ) - self.inf_with_feedback_proc = InfectionsWithFeedback( infection_feedback_strength=infection_feedback_strength_rv, infection_feedback_pmf=infection_feedback_pmf_rv, @@ -53,11 +54,27 @@ def __init__( differencing_order=1, ) + self.log_r_mu_intercept_rv = log_r_mu_intercept_rv self.autoreg_rt_rv = autoreg_rt_rv self.eta_sd_rv = eta_sd_rv - self.log_r_mu_intercept_rv = log_r_mu_intercept_rv self.generation_interval_pmf_rv = generation_interval_pmf_rv self.infection_feedback_pmf_rv = infection_feedback_pmf_rv + self.i0_first_obs_n_rv = i0_first_obs_n_rv + self.initialization_rate_rv = initialization_rate_rv + self.offset_ref_logit_i_first_obs_rv = offset_ref_logit_i_first_obs_rv + self.offset_ref_initial_exp_growth_rate_rv = ( + offset_ref_initial_exp_growth_rate_rv + ) + self.offset_ref_log_rt_rv = offset_ref_log_rt_rv + self.autoreg_rt_subpop_rv = autoreg_rt_subpop_rv + self.sigma_rt_rv = sigma_rt_rv + self.sigma_i_first_obs_rv = sigma_i_first_obs_rv + self.sigma_initial_exp_growth_rate_rv = ( + sigma_initial_exp_growth_rate_rv + ) + self.n_initialization_points = n_initialization_points + self.pop_fraction = pop_fraction + self.n_subpops = len(pop_fraction) def validate(self): pass @@ -66,8 +83,6 @@ def sample(self, n_days_post_init: int, n_weeks_post_init: int): """ Sample latent infections. """ - i0 = self.infection_initialization_process() - eta_sd = self.eta_sd_rv() autoreg_rt = self.autoreg_rt_rv() log_r_mu_intercept = self.log_r_mu_intercept_rv() @@ -85,28 +100,146 @@ def sample(self, n_days_post_init: int, n_weeks_post_init: int): noise_name="rtu_weekly_diff_first_diff_ar_process_noise", ) - rtu = repeat_until_n( - data=jnp.exp(log_rtu_weekly), - n_timepoints=n_days_post_init, - offset=0, - period_size=7, + i0_first_obs_n = self.i0_first_obs_n_rv() + initial_exp_growth_rate = self.initialization_rate_rv() + if self.n_subpops == 1: + i_first_obs_over_n_subpop = i0_first_obs_n + initial_exp_growth_rate_subpop = initial_exp_growth_rate + log_rtu_weekly_subpop = log_rtu_weekly[:, jnp.newaxis] + else: + i_first_obs_over_n_ref_subpop = transformation.SigmoidTransform()( + transformation.logit(i0_first_obs_n) + + self.offset_ref_logit_i_first_obs_rv(), + ) + initial_exp_growth_rate_ref_subpop = ( + initial_exp_growth_rate + + self.offset_ref_initial_exp_growth_rate_rv() + ) + + log_rtu_weekly_ref_subpop = ( + log_rtu_weekly + self.offset_ref_log_rt_rv() + ) + i_first_obs_over_n_non_ref_subpop_rv = TransformedVariable( + "i_first_obs_over_n_non_ref_subpop", + DistributionalVariable( + "i_first_obs_over_n_non_ref_subpop_raw", + dist.Normal( + transformation.logit(i0_first_obs_n), + self.sigma_i_first_obs_rv(), + ), + reparam=LocScaleReparam(0), + ), + transforms=transformation.SigmoidTransform(), + ) + initial_exp_growth_rate_non_ref_subpop_rv = DistributionalVariable( + "initial_exp_growth_rate_non_ref_subpop_raw", + dist.Normal( + initial_exp_growth_rate, + self.sigma_initial_exp_growth_rate_rv(), + ), + reparam=LocScaleReparam(0), + ) + + autoreg_rt_subpop = self.autoreg_rt_subpop_rv() + sigma_rt = self.sigma_rt_rv() + rtu_subpop_ar_init_rv = DistributionalVariable( + "rtu_subpop_ar_init", + dist.Normal( + 0, + sigma_rt / jnp.sqrt(1 - jnp.pow(autoreg_rt_subpop, 2)), + ), + ) + + with numpyro.plate("n_subpops", self.n_subpops - 1): + initial_exp_growth_rate_non_ref_subpop = ( + initial_exp_growth_rate_non_ref_subpop_rv() + ) + i_first_obs_over_n_non_ref_subpop = ( + i_first_obs_over_n_non_ref_subpop_rv() + ) + rtu_subpop_ar_init = rtu_subpop_ar_init_rv() + + i_first_obs_over_n_subpop = jnp.hstack( + [ + i_first_obs_over_n_ref_subpop, + i_first_obs_over_n_non_ref_subpop, + ] + ) + initial_exp_growth_rate_subpop = jnp.hstack( + [ + initial_exp_growth_rate_ref_subpop, + initial_exp_growth_rate_non_ref_subpop, + ] + ) + + rtu_subpop_ar_proc = ARProcess() + rtu_subpop_ar_weekly = rtu_subpop_ar_proc( + noise_name="rtu_ar_proc", + n=n_weeks_post_init, + init_vals=rtu_subpop_ar_init[jnp.newaxis], + autoreg=autoreg_rt_subpop[jnp.newaxis], + noise_sd=sigma_rt, + ) + + log_rtu_weekly_non_ref_subpop = ( + rtu_subpop_ar_weekly + log_rtu_weekly[:, jnp.newaxis] + ) + log_rtu_weekly_subpop = jnp.concat( + [ + log_rtu_weekly_ref_subpop[:, jnp.newaxis], + log_rtu_weekly_non_ref_subpop, + ], + axis=1, + ) + + rtu_subpop = jnp.squeeze( + jnp.repeat( + jnp.exp(log_rtu_weekly_subpop), + repeats=7, + axis=0, + )[:n_days_post_init, :] + ) + + i0_subpop_rv = DeterministicVariable( + "i0_subpop", i_first_obs_over_n_subpop + ) + initial_exp_growth_rate_subpop_rv = DeterministicVariable( + "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop + ) + infection_initialization_process = InfectionInitializationProcess( + "I0_initialization", + i0_subpop_rv, + InitializeInfectionsExponentialGrowth( + self.n_initialization_points, + initial_exp_growth_rate_subpop_rv, + t_pre_init=0, + ), ) generation_interval_pmf = self.generation_interval_pmf_rv() + i0 = infection_initialization_process() inf_with_feedback_proc_sample = self.inf_with_feedback_proc( - Rt=rtu, + Rt=rtu_subpop, I0=i0, gen_int=generation_interval_pmf, ) - latent_infections = jnp.concat( + latent_infections_subpop = jnp.concat( [ i0, inf_with_feedback_proc_sample.post_initialization_infections, ] ) - numpyro.deterministic("rtu", rtu) + + if self.n_subpops == 1: + latent_infections = latent_infections_subpop + else: + latent_infections = jnp.sum( + self.pop_fraction * latent_infections_subpop, axis=1 + ) + + numpyro.deterministic("rtu_subpop", rtu_subpop) numpyro.deterministic("rt", inf_with_feedback_proc_sample.rt) numpyro.deterministic("latent_infections", latent_infections) diff --git a/pyrenew_hew/ww_site_level_dynamics_model.py b/pyrenew_hew/ww_site_level_dynamics_model.py index fca6d313..23b32a47 100644 --- a/pyrenew_hew/ww_site_level_dynamics_model.py +++ b/pyrenew_hew/ww_site_level_dynamics_model.py @@ -187,16 +187,13 @@ def sample( eta_sd = self.eta_sd_rv() autoreg_rt = self.autoreg_rt_rv() - log_r_t_first_obs = self.log_r_t_first_obs_rv() - - rt_init_rate_of_change_rv = DistributionalVariable( + log_r_t_first_obs = self.log_r_t_first_obs_rv() # log_r_mu_intercept + rt_init_rate_of_change = DistributionalVariable( "rt_init_rate_of_change", dist.Normal(0, eta_sd / jnp.sqrt(1 - jnp.pow(autoreg_rt, 2))), - ) - - rt_init_rate_of_change = rt_init_rate_of_change_rv() + )() - log_r_t_in_weeks = self.ar_diff_rt( + log_rtu_weekly = self.ar_diff_rt( noise_name="rtu_weekly_diff_first_diff_ar_process_noise", n=n_weeks_post_init, init_vals=jnp.array(log_r_t_first_obs), @@ -204,57 +201,50 @@ def sample( noise_sd=jnp.array(eta_sd), fundamental_process_init_vals=jnp.array(rt_init_rate_of_change), ) - numpyro.deterministic("log_r_t_in_weeks", log_r_t_in_weeks) + numpyro.deterministic("log_rtu_weekly", log_rtu_weekly) i_first_obs_over_n = self.i_first_obs_over_n_rv() - offset_ref_logit_i_first_obs = self.offset_ref_logit_i_first_obs_rv() - mean_initial_exp_growth_rate = self.mean_initial_exp_growth_rate_rv() - offset_ref_initial_exp_growth_rate = ( - self.offset_ref_initial_exp_growth_rate_rv() - ) - i_first_obs_over_n_ref_subpop = transforms.SigmoidTransform()( transforms.logit(i_first_obs_over_n) - + jnp.where(self.n_subpops > 1, offset_ref_logit_i_first_obs, 0) + + jnp.where( + self.n_subpops > 1, self.offset_ref_logit_i_first_obs_rv(), 0 + ) ) initial_exp_growth_rate_ref_subpop = ( mean_initial_exp_growth_rate + jnp.where( - self.n_subpops > 1, offset_ref_initial_exp_growth_rate, 0 + self.n_subpops > 1, + self.offset_ref_initial_exp_growth_rate_rv(), + 0, ) ) - - offset_ref_log_r_t = self.offset_ref_log_r_t_rv() - log_rtu_ref_subpop_in_week = log_r_t_in_weeks + jnp.where( - self.n_subpops > 1, offset_ref_log_r_t, 0 + log_rtu_weekly_ref_subpop = log_rtu_weekly + jnp.where( + self.n_subpops > 1, self.offset_ref_log_rt_rv(), 0 ) if self.n_subpops == 1: i_first_obs_over_n_subpop = i_first_obs_over_n_ref_subpop initial_exp_growth_rate_subpop = initial_exp_growth_rate_ref_subpop - log_rtu_subpop_in_week = log_rtu_ref_subpop_in_week[:, jnp.newaxis] + log_rtu_weekly_subpop = log_rtu_weekly_ref_subpop[:, jnp.newaxis] else: - sigma_i_first_obs = self.sigma_i_first_obs_rv() i_first_obs_over_n_non_ref_subpop_rv = TransformedVariable( "i_first_obs_over_n_non_ref_subpop", DistributionalVariable( "i_first_obs_over_n_non_ref_subpop_raw", dist.Normal( - transforms.logit(i_first_obs_over_n), sigma_i_first_obs + transforms.logit(i_first_obs_over_n), + self.sigma_i_first_obs_rv(), ), reparam=LocScaleReparam(0), ), transforms=transforms.SigmoidTransform(), ) - sigma_initial_exp_growth_rate = ( - self.sigma_initial_exp_growth_rate_rv() - ) initial_exp_growth_rate_non_ref_subpop_rv = DistributionalVariable( "initial_exp_growth_rate_non_ref_subpop_raw", dist.Normal( mean_initial_exp_growth_rate, - sigma_initial_exp_growth_rate, + self.sigma_initial_exp_growth_rate_rv(), ), reparam=LocScaleReparam(0), ) @@ -300,13 +290,13 @@ def sample( noise_sd=sigma_rt, ) numpyro.deterministic("rtu_subpop_ar_weekly", rtu_subpop_ar_weekly) - log_rtu_non_ref_subpop_in_week = ( - rtu_subpop_ar_weekly + log_r_t_in_weeks[:, jnp.newaxis] + log_rtu_weekly_non_ref_subpop = ( + rtu_subpop_ar_weekly + log_rtu_weekly[:, jnp.newaxis] ) - log_rtu_subpop_in_week = jnp.concat( + log_rtu_weekly_subpop = jnp.concat( [ - log_rtu_ref_subpop_in_week[:, jnp.newaxis], - log_rtu_non_ref_subpop_in_week, + log_rtu_weekly_ref_subpop[:, jnp.newaxis], + log_rtu_weekly_non_ref_subpop, ], axis=1, ) @@ -326,7 +316,7 @@ def sample( rtu_subpop = jnp.squeeze( jnp.repeat( - jnp.exp(log_rtu_subpop_in_week), + jnp.exp(log_rtu_weekly_subpop), repeats=7, axis=0, )[:n_datapoints, :] @@ -360,24 +350,22 @@ def sample( gen_int=generation_interval_pmf, ) - new_i_subpop = jnp.atleast_2d( - jnp.concat( - [ - i0, - inf_with_feedback_proc_sample.post_initialization_infections, - ] - ) + latent_infections_subpop = jnp.concat( + [ + i0, + inf_with_feedback_proc_sample.post_initialization_infections, + ] ) - if new_i_subpop.shape[0] == 1: - new_i_subpop = new_i_subpop.T - r_subpop_t = inf_with_feedback_proc_sample.rt - numpyro.deterministic("r_subpop_t", r_subpop_t) + if self.n_subpops == 1: + latent_infections = latent_infections_subpop + else: + latent_infections = jnp.sum( + self.pop_fraction * latent_infections_subpop, axis=1 + ) - state_inf_per_capita = jnp.sum( - self.pop_fraction * new_i_subpop, axis=1 - ) - numpyro.deterministic("state_inf_per_capita", state_inf_per_capita) + numpyro.deterministic("latent_infections", latent_infections) + numpyro.deterministic("rt", inf_with_feedback_proc_sample.rt) # Hospital admission component p_hosp_mean = self.p_hosp_mean_rv() @@ -416,7 +404,7 @@ def sample( potential_latent_hospital_admissions = ( compute_delay_ascertained_incidence( p_observed_given_incident=1, - latent_incidence=state_inf_per_capita, + latent_incidence=latent_infections, delay_incidence_to_observation_pmf=inf_to_hosp, )[-n_datapoints:] ) @@ -455,7 +443,7 @@ def batch_colvolve_fn(m): return jnp.convolve(m, s, mode="valid") model_net_i = jax.vmap(batch_colvolve_fn, in_axes=1, out_axes=1)( - new_i_subpop + latent_infections_subpop )[-n_datapoints:, :] numpyro.deterministic("model_net_i", model_net_i) @@ -536,7 +524,7 @@ def batch_colvolve_fn(m): ) state_model_net_i = jnp.convolve( - state_inf_per_capita, s, mode="valid" + latent_infections, s, mode="valid" )[-n_datapoints:] numpyro.deterministic("state_model_net_i", state_model_net_i) @@ -553,9 +541,9 @@ def batch_colvolve_fn(m): ) state_rt = ( - state_inf_per_capita[-n_datapoints:] + latent_infections[-n_datapoints:] / jnp.convolve( - state_inf_per_capita, + latent_infections, jnp.hstack( (jnp.array([0]), jnp.array(generation_interval_pmf)) ), diff --git a/tests/test_LatentInfectionProcess.py b/tests/test_LatentInfectionProcess.py new file mode 100644 index 00000000..40dd192f --- /dev/null +++ b/tests/test_LatentInfectionProcess.py @@ -0,0 +1,112 @@ +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist +from pyrenew.arrayutils import repeat_until_n +from pyrenew.deterministic import DeterministicVariable +from pyrenew.latent import ( + InfectionInitializationProcess, + InfectionsWithFeedback, + InitializeInfectionsExponentialGrowth, +) +from pyrenew.process import ARProcess, DifferencedProcess +from pyrenew.randomvariable import DistributionalVariable + +from pyrenew_hew.pyrenew_hew_model import LatentInfectionProcess + + +def test_LatentInfectionProcess(): + """ + Tests when there is a single sub-population, + the hierarchical construct and manual construction + without the hierarchical component are equivalent. + """ + i0_first_obs_n_rv = DeterministicVariable("i0_first_obs_n_rv", 1e-6) + initialization_rate_rv = DeterministicVariable("rate", 0.001) + log_r_mu_intercept_rv = DeterministicVariable("log_r_mu_intercept", 0.08) + eta_sd_rv = DeterministicVariable("eta_sd", 0) + autoreg_rt_rv = DeterministicVariable("autoreg_rt", 0.4) + generation_interval_pmf_rv = DeterministicVariable( + "generation_interval_pmf", jnp.array([0.25, 0.25, 0.25, 0.25]) + ) + infection_feedback_pmf_rv = DeterministicVariable( + "infection_feedback_pmf", jnp.array([0.25, 0.25, 0.25, 0.25]) + ) + infection_feedback_strength_rv = DeterministicVariable("inf_feedback", -2) + n_initialization_points = 10 + n_days_post_init = 14 + n_weeks_post_init = 2 + + my_latent_infection_model = LatentInfectionProcess( + i0_first_obs_n_rv=i0_first_obs_n_rv, + initialization_rate_rv=initialization_rate_rv, + log_r_mu_intercept_rv=log_r_mu_intercept_rv, + autoreg_rt_rv=autoreg_rt_rv, + eta_sd_rv=eta_sd_rv, # sd of random walk for ar process, + generation_interval_pmf_rv=generation_interval_pmf_rv, + infection_feedback_pmf_rv=infection_feedback_pmf_rv, + infection_feedback_strength_rv=infection_feedback_strength_rv, + n_initialization_points=n_initialization_points, + ) + + with numpyro.handlers.seed(rng_seed=223): + latent_inf_w_hierarchical_effects = my_latent_infection_model( + n_days_post_init=n_days_post_init, + n_weeks_post_init=n_weeks_post_init, + ) + + # Calculate latent infections without hierarchical dynamics + i0 = InfectionInitializationProcess( + "I0_initialization", + i0_first_obs_n_rv, + InitializeInfectionsExponentialGrowth( + n_initialization_points, initialization_rate_rv, t_pre_init=0 + ), + )() + + inf_with_feedback_proc = InfectionsWithFeedback( + infection_feedback_strength=infection_feedback_strength_rv, + infection_feedback_pmf=infection_feedback_pmf_rv, + ) + + ar_diff = DifferencedProcess( + fundamental_process=ARProcess(), + differencing_order=1, + ) + + rt_init_rate_of_change = DistributionalVariable( + "rt_init_rate_of_change", + dist.Normal( + 0, eta_sd_rv() / jnp.sqrt(1 - jnp.pow(autoreg_rt_rv(), 2)) + ), + )() + + log_rtu_weekly = ar_diff( + n=2, + init_vals=jnp.array(log_r_mu_intercept_rv()), + autoreg=jnp.array(autoreg_rt_rv()), + noise_sd=jnp.array(eta_sd_rv()), + fundamental_process_init_vals=jnp.array(rt_init_rate_of_change), + noise_name="rtu_weekly_diff_first_diff_ar_process_noise", + ) + rtu = repeat_until_n( + data=jnp.exp(log_rtu_weekly), + n_timepoints=n_days_post_init, + offset=0, + period_size=7, + ) + inf_with_feedback_proc_sample = inf_with_feedback_proc( + Rt=rtu, + I0=i0, + gen_int=generation_interval_pmf_rv(), + ) + + latent_inf_wo_hierarchical_effects = jnp.concat( + [ + i0, + inf_with_feedback_proc_sample.post_initialization_infections, + ] + ) + + assert jnp.allclose( + latent_inf_w_hierarchical_effects, latent_inf_wo_hierarchical_effects + )