Skip to content

Commit

Permalink
add subpopulation to LatentInfectionProcess (#282)
Browse files Browse the repository at this point in the history
  • Loading branch information
sbidari authored Jan 17, 2025
1 parent a70e296 commit bc76d59
Show file tree
Hide file tree
Showing 3 changed files with 305 additions and 72 deletions.
171 changes: 152 additions & 19 deletions pyrenew_hew/pyrenew_hew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpyro.distributions as dist
import pyrenew.transformation as transformation
from jax.typing import ArrayLike
from numpyro.infer.reparam import LocScaleReparam
from pyrenew.arrayutils import repeat_until_n, tile_until_n
from pyrenew.convolve import compute_delay_ascertained_incidence
from pyrenew.deterministic import DeterministicVariable
Expand Down Expand Up @@ -34,15 +35,15 @@ def __init__(
infection_feedback_strength_rv: RandomVariable,
infection_feedback_pmf_rv: RandomVariable,
n_initialization_points: int,
pop_fraction: ArrayLike = jnp.array([1]),
autoreg_rt_subpop_rv: RandomVariable = None,
sigma_rt_rv: RandomVariable = None,
sigma_i_first_obs_rv: RandomVariable = None,
sigma_initial_exp_growth_rate_rv: RandomVariable = None,
offset_ref_logit_i_first_obs_rv: RandomVariable = None,
offset_ref_initial_exp_growth_rate_rv: RandomVariable = None,
offset_ref_log_rt_rv: RandomVariable = None,
) -> None:
self.infection_initialization_process = InfectionInitializationProcess(
"I0_initialization",
i0_first_obs_n_rv,
InitializeInfectionsExponentialGrowth(
n_initialization_points, initialization_rate_rv, t_pre_init=0
),
)

self.inf_with_feedback_proc = InfectionsWithFeedback(
infection_feedback_strength=infection_feedback_strength_rv,
infection_feedback_pmf=infection_feedback_pmf_rv,
Expand All @@ -53,11 +54,27 @@ def __init__(
differencing_order=1,
)

self.log_r_mu_intercept_rv = log_r_mu_intercept_rv
self.autoreg_rt_rv = autoreg_rt_rv
self.eta_sd_rv = eta_sd_rv
self.log_r_mu_intercept_rv = log_r_mu_intercept_rv
self.generation_interval_pmf_rv = generation_interval_pmf_rv
self.infection_feedback_pmf_rv = infection_feedback_pmf_rv
self.i0_first_obs_n_rv = i0_first_obs_n_rv
self.initialization_rate_rv = initialization_rate_rv
self.offset_ref_logit_i_first_obs_rv = offset_ref_logit_i_first_obs_rv
self.offset_ref_initial_exp_growth_rate_rv = (
offset_ref_initial_exp_growth_rate_rv
)
self.offset_ref_log_rt_rv = offset_ref_log_rt_rv
self.autoreg_rt_subpop_rv = autoreg_rt_subpop_rv
self.sigma_rt_rv = sigma_rt_rv
self.sigma_i_first_obs_rv = sigma_i_first_obs_rv
self.sigma_initial_exp_growth_rate_rv = (
sigma_initial_exp_growth_rate_rv
)
self.n_initialization_points = n_initialization_points
self.pop_fraction = pop_fraction
self.n_subpops = len(pop_fraction)

def validate(self):
pass
Expand All @@ -66,8 +83,6 @@ def sample(self, n_days_post_init: int, n_weeks_post_init: int):
"""
Sample latent infections.
"""
i0 = self.infection_initialization_process()

eta_sd = self.eta_sd_rv()
autoreg_rt = self.autoreg_rt_rv()
log_r_mu_intercept = self.log_r_mu_intercept_rv()
Expand All @@ -85,28 +100,146 @@ def sample(self, n_days_post_init: int, n_weeks_post_init: int):
noise_name="rtu_weekly_diff_first_diff_ar_process_noise",
)

rtu = repeat_until_n(
data=jnp.exp(log_rtu_weekly),
n_timepoints=n_days_post_init,
offset=0,
period_size=7,
i0_first_obs_n = self.i0_first_obs_n_rv()
initial_exp_growth_rate = self.initialization_rate_rv()
if self.n_subpops == 1:
i_first_obs_over_n_subpop = i0_first_obs_n
initial_exp_growth_rate_subpop = initial_exp_growth_rate
log_rtu_weekly_subpop = log_rtu_weekly[:, jnp.newaxis]
else:
i_first_obs_over_n_ref_subpop = transformation.SigmoidTransform()(
transformation.logit(i0_first_obs_n)
+ self.offset_ref_logit_i_first_obs_rv(),
)
initial_exp_growth_rate_ref_subpop = (
initial_exp_growth_rate
+ self.offset_ref_initial_exp_growth_rate_rv()
)

log_rtu_weekly_ref_subpop = (
log_rtu_weekly + self.offset_ref_log_rt_rv()
)
i_first_obs_over_n_non_ref_subpop_rv = TransformedVariable(
"i_first_obs_over_n_non_ref_subpop",
DistributionalVariable(
"i_first_obs_over_n_non_ref_subpop_raw",
dist.Normal(
transformation.logit(i0_first_obs_n),
self.sigma_i_first_obs_rv(),
),
reparam=LocScaleReparam(0),
),
transforms=transformation.SigmoidTransform(),
)
initial_exp_growth_rate_non_ref_subpop_rv = DistributionalVariable(
"initial_exp_growth_rate_non_ref_subpop_raw",
dist.Normal(
initial_exp_growth_rate,
self.sigma_initial_exp_growth_rate_rv(),
),
reparam=LocScaleReparam(0),
)

autoreg_rt_subpop = self.autoreg_rt_subpop_rv()
sigma_rt = self.sigma_rt_rv()
rtu_subpop_ar_init_rv = DistributionalVariable(
"rtu_subpop_ar_init",
dist.Normal(
0,
sigma_rt / jnp.sqrt(1 - jnp.pow(autoreg_rt_subpop, 2)),
),
)

with numpyro.plate("n_subpops", self.n_subpops - 1):
initial_exp_growth_rate_non_ref_subpop = (
initial_exp_growth_rate_non_ref_subpop_rv()
)
i_first_obs_over_n_non_ref_subpop = (
i_first_obs_over_n_non_ref_subpop_rv()
)
rtu_subpop_ar_init = rtu_subpop_ar_init_rv()

i_first_obs_over_n_subpop = jnp.hstack(
[
i_first_obs_over_n_ref_subpop,
i_first_obs_over_n_non_ref_subpop,
]
)
initial_exp_growth_rate_subpop = jnp.hstack(
[
initial_exp_growth_rate_ref_subpop,
initial_exp_growth_rate_non_ref_subpop,
]
)

rtu_subpop_ar_proc = ARProcess()
rtu_subpop_ar_weekly = rtu_subpop_ar_proc(
noise_name="rtu_ar_proc",
n=n_weeks_post_init,
init_vals=rtu_subpop_ar_init[jnp.newaxis],
autoreg=autoreg_rt_subpop[jnp.newaxis],
noise_sd=sigma_rt,
)

log_rtu_weekly_non_ref_subpop = (
rtu_subpop_ar_weekly + log_rtu_weekly[:, jnp.newaxis]
)
log_rtu_weekly_subpop = jnp.concat(
[
log_rtu_weekly_ref_subpop[:, jnp.newaxis],
log_rtu_weekly_non_ref_subpop,
],
axis=1,
)

rtu_subpop = jnp.squeeze(
jnp.repeat(
jnp.exp(log_rtu_weekly_subpop),
repeats=7,
axis=0,
)[:n_days_post_init, :]
)

i0_subpop_rv = DeterministicVariable(
"i0_subpop", i_first_obs_over_n_subpop
)
initial_exp_growth_rate_subpop_rv = DeterministicVariable(
"initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop
)
infection_initialization_process = InfectionInitializationProcess(
"I0_initialization",
i0_subpop_rv,
InitializeInfectionsExponentialGrowth(
self.n_initialization_points,
initial_exp_growth_rate_subpop_rv,
t_pre_init=0,
),
)

generation_interval_pmf = self.generation_interval_pmf_rv()
i0 = infection_initialization_process()

inf_with_feedback_proc_sample = self.inf_with_feedback_proc(
Rt=rtu,
Rt=rtu_subpop,
I0=i0,
gen_int=generation_interval_pmf,
)

latent_infections = jnp.concat(
latent_infections_subpop = jnp.concat(
[
i0,
inf_with_feedback_proc_sample.post_initialization_infections,
]
)
numpyro.deterministic("rtu", rtu)

if self.n_subpops == 1:
latent_infections = latent_infections_subpop
else:
latent_infections = jnp.sum(
self.pop_fraction * latent_infections_subpop, axis=1
)

numpyro.deterministic("rtu_subpop", rtu_subpop)
numpyro.deterministic("rt", inf_with_feedback_proc_sample.rt)
numpyro.deterministic("latent_infections", latent_infections)

Expand Down
Loading

0 comments on commit bc76d59

Please sign in to comment.