Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add subpopulation to LatentInfectionProcess #282

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 152 additions & 19 deletions pyrenew_hew/pyrenew_hew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpyro.distributions as dist
import pyrenew.transformation as transformation
from jax.typing import ArrayLike
from numpyro.infer.reparam import LocScaleReparam
from pyrenew.arrayutils import repeat_until_n, tile_until_n
from pyrenew.convolve import compute_delay_ascertained_incidence
from pyrenew.deterministic import DeterministicVariable
Expand Down Expand Up @@ -34,15 +35,15 @@
infection_feedback_strength_rv: RandomVariable,
infection_feedback_pmf_rv: RandomVariable,
n_initialization_points: int,
pop_fraction: ArrayLike = jnp.array([1]),
autoreg_rt_subpop_rv: RandomVariable = None,
sigma_rt_rv: RandomVariable = None,
sigma_i_first_obs_rv: RandomVariable = None,
sigma_initial_exp_growth_rate_rv: RandomVariable = None,
offset_ref_logit_i_first_obs_rv: RandomVariable = None,
offset_ref_initial_exp_growth_rate_rv: RandomVariable = None,
offset_ref_log_rt_rv: RandomVariable = None,
) -> None:
self.infection_initialization_process = InfectionInitializationProcess(
"I0_initialization",
i0_first_obs_n_rv,
InitializeInfectionsExponentialGrowth(
n_initialization_points, initialization_rate_rv, t_pre_init=0
),
)

self.inf_with_feedback_proc = InfectionsWithFeedback(
infection_feedback_strength=infection_feedback_strength_rv,
infection_feedback_pmf=infection_feedback_pmf_rv,
Expand All @@ -53,11 +54,27 @@
differencing_order=1,
)

self.log_r_mu_intercept_rv = log_r_mu_intercept_rv
self.autoreg_rt_rv = autoreg_rt_rv
self.eta_sd_rv = eta_sd_rv
self.log_r_mu_intercept_rv = log_r_mu_intercept_rv
self.generation_interval_pmf_rv = generation_interval_pmf_rv
self.infection_feedback_pmf_rv = infection_feedback_pmf_rv
self.i0_first_obs_n_rv = i0_first_obs_n_rv
self.initialization_rate_rv = initialization_rate_rv
self.offset_ref_logit_i_first_obs_rv = offset_ref_logit_i_first_obs_rv
self.offset_ref_initial_exp_growth_rate_rv = (
offset_ref_initial_exp_growth_rate_rv
)
self.offset_ref_log_rt_rv = offset_ref_log_rt_rv
self.autoreg_rt_subpop_rv = autoreg_rt_subpop_rv
self.sigma_rt_rv = sigma_rt_rv
self.sigma_i_first_obs_rv = sigma_i_first_obs_rv
self.sigma_initial_exp_growth_rate_rv = (
sigma_initial_exp_growth_rate_rv
)
self.n_initialization_points = n_initialization_points
self.pop_fraction = pop_fraction
self.n_subpops = len(pop_fraction)

def validate(self):
pass
Expand All @@ -66,8 +83,6 @@
"""
Sample latent infections.
"""
i0 = self.infection_initialization_process()

eta_sd = self.eta_sd_rv()
autoreg_rt = self.autoreg_rt_rv()
log_r_mu_intercept = self.log_r_mu_intercept_rv()
Expand All @@ -85,28 +100,146 @@
noise_name="rtu_weekly_diff_first_diff_ar_process_noise",
)

rtu = repeat_until_n(
data=jnp.exp(log_rtu_weekly),
n_timepoints=n_days_post_init,
offset=0,
period_size=7,
i0_first_obs_n = self.i0_first_obs_n_rv()
initial_exp_growth_rate = self.initialization_rate_rv()
if self.n_subpops == 1:
i_first_obs_over_n_subpop = i0_first_obs_n
initial_exp_growth_rate_subpop = initial_exp_growth_rate
log_rtu_weekly_subpop = log_rtu_weekly[:, jnp.newaxis]
else:
i_first_obs_over_n_ref_subpop = transformation.SigmoidTransform()(

Check warning on line 110 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L110

Added line #L110 was not covered by tests
transformation.logit(i0_first_obs_n)
+ self.offset_ref_logit_i_first_obs_rv(),
)
initial_exp_growth_rate_ref_subpop = (

Check warning on line 114 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L114

Added line #L114 was not covered by tests
initial_exp_growth_rate
+ self.offset_ref_initial_exp_growth_rate_rv()
)

log_rtu_weekly_ref_subpop = (

Check warning on line 119 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L119

Added line #L119 was not covered by tests
log_rtu_weekly + self.offset_ref_log_rt_rv()
)
i_first_obs_over_n_non_ref_subpop_rv = TransformedVariable(

Check warning on line 122 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L122

Added line #L122 was not covered by tests
"i_first_obs_over_n_non_ref_subpop",
DistributionalVariable(
"i_first_obs_over_n_non_ref_subpop_raw",
dist.Normal(
transformation.logit(i0_first_obs_n),
self.sigma_i_first_obs_rv(),
),
reparam=LocScaleReparam(0),
),
transforms=transformation.SigmoidTransform(),
)
initial_exp_growth_rate_non_ref_subpop_rv = DistributionalVariable(

Check warning on line 134 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L134

Added line #L134 was not covered by tests
"initial_exp_growth_rate_non_ref_subpop_raw",
dist.Normal(
initial_exp_growth_rate,
self.sigma_initial_exp_growth_rate_rv(),
),
reparam=LocScaleReparam(0),
)

autoreg_rt_subpop = self.autoreg_rt_subpop_rv()
sigma_rt = self.sigma_rt_rv()
rtu_subpop_ar_init_rv = DistributionalVariable(

Check warning on line 145 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L143-L145

Added lines #L143 - L145 were not covered by tests
"rtu_subpop_ar_init",
dist.Normal(
0,
sigma_rt / jnp.sqrt(1 - jnp.pow(autoreg_rt_subpop, 2)),
),
)

with numpyro.plate("n_subpops", self.n_subpops - 1):
initial_exp_growth_rate_non_ref_subpop = (

Check warning on line 154 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L153-L154

Added lines #L153 - L154 were not covered by tests
initial_exp_growth_rate_non_ref_subpop_rv()
)
i_first_obs_over_n_non_ref_subpop = (

Check warning on line 157 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L157

Added line #L157 was not covered by tests
i_first_obs_over_n_non_ref_subpop_rv()
)
rtu_subpop_ar_init = rtu_subpop_ar_init_rv()

Check warning on line 160 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L160

Added line #L160 was not covered by tests

i_first_obs_over_n_subpop = jnp.hstack(

Check warning on line 162 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L162

Added line #L162 was not covered by tests
[
i_first_obs_over_n_ref_subpop,
i_first_obs_over_n_non_ref_subpop,
]
)
initial_exp_growth_rate_subpop = jnp.hstack(

Check warning on line 168 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L168

Added line #L168 was not covered by tests
[
initial_exp_growth_rate_ref_subpop,
initial_exp_growth_rate_non_ref_subpop,
]
)

rtu_subpop_ar_proc = ARProcess()
rtu_subpop_ar_weekly = rtu_subpop_ar_proc(

Check warning on line 176 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L175-L176

Added lines #L175 - L176 were not covered by tests
noise_name="rtu_ar_proc",
n=n_weeks_post_init,
init_vals=rtu_subpop_ar_init[jnp.newaxis],
autoreg=autoreg_rt_subpop[jnp.newaxis],
noise_sd=sigma_rt,
)

log_rtu_weekly_non_ref_subpop = (

Check warning on line 184 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L184

Added line #L184 was not covered by tests
rtu_subpop_ar_weekly + log_rtu_weekly[:, jnp.newaxis]
)
log_rtu_weekly_subpop = jnp.concat(

Check warning on line 187 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L187

Added line #L187 was not covered by tests
[
log_rtu_weekly_ref_subpop[:, jnp.newaxis],
log_rtu_weekly_non_ref_subpop,
],
axis=1,
)

rtu_subpop = jnp.squeeze(
jnp.repeat(
jnp.exp(log_rtu_weekly_subpop),
repeats=7,
axis=0,
)[:n_days_post_init, :]
)

i0_subpop_rv = DeterministicVariable(
"i0_subpop", i_first_obs_over_n_subpop
)
initial_exp_growth_rate_subpop_rv = DeterministicVariable(
"initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop
)
infection_initialization_process = InfectionInitializationProcess(
"I0_initialization",
i0_subpop_rv,
InitializeInfectionsExponentialGrowth(
self.n_initialization_points,
initial_exp_growth_rate_subpop_rv,
t_pre_init=0,
),
)

generation_interval_pmf = self.generation_interval_pmf_rv()
i0 = infection_initialization_process()

inf_with_feedback_proc_sample = self.inf_with_feedback_proc(
Rt=rtu,
Rt=rtu_subpop,
I0=i0,
gen_int=generation_interval_pmf,
)

latent_infections = jnp.concat(
latent_infections_subpop = jnp.concat(
[
i0,
inf_with_feedback_proc_sample.post_initialization_infections,
]
)
numpyro.deterministic("rtu", rtu)

if self.n_subpops == 1:
latent_infections = latent_infections_subpop
else:
latent_infections = jnp.sum(

Check warning on line 238 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L238

Added line #L238 was not covered by tests
self.pop_fraction * latent_infections_subpop, axis=1
)

numpyro.deterministic("rtu_subpop", rtu_subpop)
numpyro.deterministic("rt", inf_with_feedback_proc_sample.rt)
numpyro.deterministic("latent_infections", latent_infections)

Expand Down
Loading
Loading