From 50dc084c037d0d630675ee56905aeedb68250d3e Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 10 Jan 2025 16:34:11 -0500 Subject: [PATCH] reorg --- pyrenew_hew/pyrenew_hew_model.py | 124 +++++++++++++++++-------------- 1 file changed, 69 insertions(+), 55 deletions(-) diff --git a/pyrenew_hew/pyrenew_hew_model.py b/pyrenew_hew/pyrenew_hew_model.py index 27fefe9a..a7b834af 100644 --- a/pyrenew_hew/pyrenew_hew_model.py +++ b/pyrenew_hew/pyrenew_hew_model.py @@ -5,9 +5,9 @@ import numpyro import numpyro.distributions as dist import numpyro.distributions.transforms as transforms -from numpyro.infer.reparam import LocScaleReparam import pyrenew.transformation as transformation from jax.typing import ArrayLike +from numpyro.infer.reparam import LocScaleReparam from pyrenew.arrayutils import repeat_until_n, tile_until_n from pyrenew.convolve import compute_delay_ascertained_incidence from pyrenew.deterministic import DeterministicVariable @@ -27,17 +27,14 @@ class LatentInfectionProcess(RandomVariable): def __init__( self, + i0_first_obs_n_rv: RandomVariable, + initialization_rate_rv: RandomVariable, log_r_mu_intercept_rv: RandomVariable, autoreg_rt_rv: RandomVariable, # ar coefficient of AR(1) process on R'(t) eta_sd_rv: RandomVariable, # sd of random walk for ar process on R'(t) generation_interval_pmf_rv: RandomVariable, infection_feedback_strength_rv: RandomVariable, infection_feedback_pmf_rv: RandomVariable, - i_first_obs_over_n_rv: RandomVariable, - mean_initial_exp_growth_rate_rv: RandomVariable, - offset_ref_logit_i_first_obs_rv: RandomVariable, - offset_ref_initial_exp_growth_rate_rv: RandomVariable, - offset_ref_log_r_t_rv: RandomVariable, n_initialization_points: int, pop_fraction: float = 1, n_subpops: int = 1, @@ -45,8 +42,10 @@ def __init__( sigma_rt_rv: RandomVariable = None, sigma_i_first_obs_rv: RandomVariable = None, sigma_initial_exp_growth_rate_rv: RandomVariable = None, + offset_ref_logit_i_first_obs_rv: RandomVariable = None, + offset_ref_initial_exp_growth_rate_rv: RandomVariable = None, + offset_ref_log_rt_rv: RandomVariable = None, ) -> None: - self.inf_with_feedback_proc = InfectionsWithFeedback( infection_feedback_strength=infection_feedback_strength_rv, infection_feedback_pmf=infection_feedback_pmf_rv, @@ -62,18 +61,19 @@ def __init__( self.eta_sd_rv = eta_sd_rv self.generation_interval_pmf_rv = generation_interval_pmf_rv self.infection_feedback_pmf_rv = infection_feedback_pmf_rv - - self.i_first_obs_over_n_rv = i_first_obs_over_n_rv - self.mean_initial_exp_growth_rate_rv = mean_initial_exp_growth_rate_rv + self.i0_first_obs_n_rv = i0_first_obs_n_rv + self.initialization_rate_rv = initialization_rate_rv self.offset_ref_logit_i_first_obs_rv = offset_ref_logit_i_first_obs_rv self.offset_ref_initial_exp_growth_rate_rv = ( offset_ref_initial_exp_growth_rate_rv ) - self.offset_ref_log_r_t_rv = offset_ref_log_r_t_rv + self.offset_ref_log_rt_rv = offset_ref_log_rt_rv self.autoreg_rt_subpop_rv = autoreg_rt_subpop_rv self.sigma_rt_rv = sigma_rt_rv self.sigma_i_first_obs_rv = sigma_i_first_obs_rv - self.sigma_initial_exp_growth_rate_rv = sigma_initial_exp_growth_rate_rv + self.sigma_initial_exp_growth_rate_rv = ( + sigma_initial_exp_growth_rate_rv + ) self.n_initialization_points = n_initialization_points self.pop_fraction = pop_fraction self.n_subpops = n_subpops @@ -102,50 +102,47 @@ def sample(self, n_days_post_init: int, n_weeks_post_init: int): noise_name="rtu_weekly_diff_first_diff_ar_process_noise", ) - i_first_obs_over_n = self.i_first_obs_over_n_rv() - offset_ref_logit_i_first_obs = self.offset_ref_logit_i_first_obs_rv() - - mean_initial_exp_growth_rate = self.mean_initial_exp_growth_rate_rv() - offset_ref_initial_exp_growth_rate = ( - self.offset_ref_initial_exp_growth_rate_rv() - ) - - i_first_obs_over_n_ref_subpop = transforms.SigmoidTransform()( - transforms.logit(i_first_obs_over_n) - + jnp.where(self.n_subpops > 1, offset_ref_logit_i_first_obs, 0) - ) - initial_exp_growth_rate_ref_subpop = mean_initial_exp_growth_rate + jnp.where( - self.n_subpops > 1, offset_ref_initial_exp_growth_rate, 0 - ) - - offset_ref_log_r_t = self.offset_ref_log_r_t_rv() - log_rtu_ref_subpop_in_week = log_rtu_weekly + jnp.where( - self.n_subpops > 1, offset_ref_log_r_t, 0 - ) - if self.n_subpops == 1: - i_first_obs_over_n_subpop = i_first_obs_over_n_ref_subpop - initial_exp_growth_rate_subpop = initial_exp_growth_rate_ref_subpop - log_rtu_weekly_subpop = log_rtu_ref_subpop_in_week[:, jnp.newaxis] + i_first_obs_over_n_subpop = self.i0_first_obs_n_rv() + initial_exp_growth_rate_subpop = self.initialization_rate_rv() + log_rtu_weekly_subpop = log_rtu_weekly[:, jnp.newaxis] else: - sigma_i_first_obs = self.sigma_i_first_obs_rv() + i_first_obs_over_n_ref_subpop = transforms.SigmoidTransform()( + transforms.logit(self.i0_first_obs_n_rv()) + + jnp.where( + self.n_subpops > 1, + self.offset_ref_logit_i_first_obs_rv(), + 0, + ) + ) + initial_exp_growth_rate_ref_subpop = ( + self.initialization_rate_rv() + + jnp.where( + self.n_subpops > 1, + self.offset_ref_initial_exp_growth_rate_rv(), + 0, + ) + ) + log_rtu_weekly_ref_subpop = log_rtu_weekly + jnp.where( + self.n_subpops > 1, self.offset_ref_log_rt_rv(), 0 + ) i_first_obs_over_n_non_ref_subpop_rv = TransformedVariable( "i_first_obs_over_n_non_ref_subpop", DistributionalVariable( "i_first_obs_over_n_non_ref_subpop_raw", dist.Normal( - transforms.logit(i_first_obs_over_n), sigma_i_first_obs + transforms.logit(self.i0_first_obs_n_rv()), + self.sigma_i_first_obs_rv(), ), reparam=LocScaleReparam(0), ), transforms=transforms.SigmoidTransform(), ) - sigma_initial_exp_growth_rate = self.sigma_initial_exp_growth_rate_rv() initial_exp_growth_rate_non_ref_subpop_rv = DistributionalVariable( "initial_exp_growth_rate_non_ref_subpop_raw", dist.Normal( - mean_initial_exp_growth_rate, - sigma_initial_exp_growth_rate, + self.initialization_rate_rv(), + self.sigma_initial_exp_growth_rate_rv(), ), reparam=LocScaleReparam(0), ) @@ -191,13 +188,13 @@ def sample(self, n_days_post_init: int, n_weeks_post_init: int): noise_sd=sigma_rt, ) - log_rtu_non_ref_subpop_in_week = ( + log_rtu_weekly_non_ref_subpop = ( rtu_subpop_ar_weekly + log_rtu_weekly[:, jnp.newaxis] ) log_rtu_weekly_subpop = jnp.concat( [ - log_rtu_ref_subpop_in_week[:, jnp.newaxis], - log_rtu_non_ref_subpop_in_week, + log_rtu_weekly_ref_subpop[:, jnp.newaxis], + log_rtu_weekly_non_ref_subpop, ], axis=1, ) @@ -214,7 +211,9 @@ def sample(self, n_days_post_init: int, n_weeks_post_init: int): jnp.log(i_first_obs_over_n_subpop) - self.unobs_time * initial_exp_growth_rate_subpop ) - i0_subpop_rv = DeterministicVariable("i0_subpop", jnp.exp(log_i0_subpop)) + i0_subpop_rv = DeterministicVariable( + "i0_subpop", jnp.exp(log_i0_subpop) + ) initial_exp_growth_rate_subpop_rv = DeterministicVariable( "initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop ) @@ -248,15 +247,15 @@ def sample(self, n_days_post_init: int, n_weeks_post_init: int): if latent_infections_subpop.shape[0] == 1: latent_infections_subpop = latent_infections_subpop.T - latent_infections_total = jnp.sum( + latent_infections = jnp.sum( self.pop_fraction * latent_infections_subpop, axis=1 ) numpyro.deterministic("rtu_subpop", rtu_subpop) numpyro.deterministic("rt", inf_with_feedback_proc_sample.rt) - numpyro.deterministic("latent_infections_total", latent_infections_total) + numpyro.deterministic("latent_infections", latent_infections) - return latent_infections_total + return latent_infections class EDVisitObservationProcess(RandomVariable): @@ -342,7 +341,10 @@ def sample( )[-n_observed_disease_ed_visits_datapoints:] latent_ed_visits_final = ( - potential_latent_ed_visits * iedr * ed_wday_effect * population_size + potential_latent_ed_visits + * iedr + * ed_wday_effect + * population_size ) if right_truncation_offset is not None: @@ -359,7 +361,9 @@ def sample( mode="constant", constant_values=(1, 0), ) - latent_ed_visits_now = latent_ed_visits_final * prop_already_reported + latent_ed_visits_now = ( + latent_ed_visits_final * prop_already_reported + ) else: latent_ed_visits_now = latent_ed_visits_final @@ -383,7 +387,9 @@ def __init__( hosp_admit_neg_bin_concentration_rv: RandomVariable, ): self.inf_to_hosp_admit_rv = inf_to_hosp_admit_rv - self.hosp_admit_neg_bin_concentration_rv = hosp_admit_neg_bin_concentration_rv + self.hosp_admit_neg_bin_concentration_rv = ( + hosp_admit_neg_bin_concentration_rv + ) def validate(self): pass @@ -513,7 +519,9 @@ def sample( sampled_ed_visits = self.ed_visit_obs_process_rv( latent_infections=latent_infections, population_size=self.population_size, - data_observed_disease_ed_visits=(data_observed_disease_ed_visits), + data_observed_disease_ed_visits=( + data_observed_disease_ed_visits + ), n_observed_disease_ed_visits_datapoints=( n_observed_disease_ed_visits_datapoints ), @@ -606,7 +614,9 @@ def create_pyrenew_hew_model_from_stan_data(stan_data_file): "inf_feedback", DistributionalVariable( "inf_feedback_raw", - dist.LogNormal(inf_feedback_prior_logmean, inf_feedback_prior_logsd), + dist.LogNormal( + inf_feedback_prior_logmean, inf_feedback_prior_logsd + ), ), transforms=transformation.AffineTransform(loc=0, scale=-1), ) @@ -638,7 +648,9 @@ def create_pyrenew_hew_model_from_stan_data(stan_data_file): "ed_wday_effect", DistributionalVariable( "ed_wday_effect_raw", - dist.Dirichlet(jnp.array(stan_data["hosp_wday_effect_prior_alpha"])), + dist.Dirichlet( + jnp.array(stan_data["hosp_wday_effect_prior_alpha"]) + ), ), transformation.AffineTransform(loc=0, scale=7), ) @@ -713,7 +725,9 @@ def create_pyrenew_hew_model_from_stan_data(stan_data_file): my_hosp_admit_obs_model = HospAdmitObservationProcess( inf_to_hosp_admit_rv=inf_to_hosp_admit_rv, - hosp_admit_neg_bin_concentration_rv=(hosp_admit_neg_bin_concentration_rv), + hosp_admit_neg_bin_concentration_rv=( + hosp_admit_neg_bin_concentration_rv + ), ) my_wastewater_obs_model = WastewaterObservationProcess()