Skip to content

Commit

Permalink
reorg
Browse files Browse the repository at this point in the history
  • Loading branch information
sbidari committed Jan 10, 2025
1 parent 6aa024b commit 50dc084
Showing 1 changed file with 69 additions and 55 deletions.
124 changes: 69 additions & 55 deletions pyrenew_hew/pyrenew_hew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import numpyro
import numpyro.distributions as dist
import numpyro.distributions.transforms as transforms

Check warning on line 7 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L7

Added line #L7 was not covered by tests
from numpyro.infer.reparam import LocScaleReparam
import pyrenew.transformation as transformation
from jax.typing import ArrayLike
from numpyro.infer.reparam import LocScaleReparam

Check warning on line 10 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L10

Added line #L10 was not covered by tests
from pyrenew.arrayutils import repeat_until_n, tile_until_n
from pyrenew.convolve import compute_delay_ascertained_incidence
from pyrenew.deterministic import DeterministicVariable
Expand All @@ -27,26 +27,25 @@
class LatentInfectionProcess(RandomVariable):
def __init__(
self,
i0_first_obs_n_rv: RandomVariable,
initialization_rate_rv: RandomVariable,
log_r_mu_intercept_rv: RandomVariable,
autoreg_rt_rv: RandomVariable, # ar coefficient of AR(1) process on R'(t)
eta_sd_rv: RandomVariable, # sd of random walk for ar process on R'(t)
generation_interval_pmf_rv: RandomVariable,
infection_feedback_strength_rv: RandomVariable,
infection_feedback_pmf_rv: RandomVariable,
i_first_obs_over_n_rv: RandomVariable,
mean_initial_exp_growth_rate_rv: RandomVariable,
offset_ref_logit_i_first_obs_rv: RandomVariable,
offset_ref_initial_exp_growth_rate_rv: RandomVariable,
offset_ref_log_r_t_rv: RandomVariable,
n_initialization_points: int,
pop_fraction: float = 1,
n_subpops: int = 1,
autoreg_rt_subpop_rv: RandomVariable = None,
sigma_rt_rv: RandomVariable = None,
sigma_i_first_obs_rv: RandomVariable = None,
sigma_initial_exp_growth_rate_rv: RandomVariable = None,
offset_ref_logit_i_first_obs_rv: RandomVariable = None,
offset_ref_initial_exp_growth_rate_rv: RandomVariable = None,
offset_ref_log_rt_rv: RandomVariable = None,
) -> None:

self.inf_with_feedback_proc = InfectionsWithFeedback(
infection_feedback_strength=infection_feedback_strength_rv,
infection_feedback_pmf=infection_feedback_pmf_rv,
Expand All @@ -62,18 +61,19 @@ def __init__(
self.eta_sd_rv = eta_sd_rv
self.generation_interval_pmf_rv = generation_interval_pmf_rv
self.infection_feedback_pmf_rv = infection_feedback_pmf_rv

self.i_first_obs_over_n_rv = i_first_obs_over_n_rv
self.mean_initial_exp_growth_rate_rv = mean_initial_exp_growth_rate_rv
self.i0_first_obs_n_rv = i0_first_obs_n_rv
self.initialization_rate_rv = initialization_rate_rv
self.offset_ref_logit_i_first_obs_rv = offset_ref_logit_i_first_obs_rv
self.offset_ref_initial_exp_growth_rate_rv = (

Check warning on line 67 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L64-L67

Added lines #L64 - L67 were not covered by tests
offset_ref_initial_exp_growth_rate_rv
)
self.offset_ref_log_r_t_rv = offset_ref_log_r_t_rv
self.offset_ref_log_rt_rv = offset_ref_log_rt_rv
self.autoreg_rt_subpop_rv = autoreg_rt_subpop_rv
self.sigma_rt_rv = sigma_rt_rv
self.sigma_i_first_obs_rv = sigma_i_first_obs_rv
self.sigma_initial_exp_growth_rate_rv = sigma_initial_exp_growth_rate_rv
self.sigma_initial_exp_growth_rate_rv = (

Check warning on line 74 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L70-L74

Added lines #L70 - L74 were not covered by tests
sigma_initial_exp_growth_rate_rv
)
self.n_initialization_points = n_initialization_points
self.pop_fraction = pop_fraction
self.n_subpops = n_subpops

Check warning on line 79 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L77-L79

Added lines #L77 - L79 were not covered by tests
Expand Down Expand Up @@ -102,50 +102,47 @@ def sample(self, n_days_post_init: int, n_weeks_post_init: int):
noise_name="rtu_weekly_diff_first_diff_ar_process_noise",
)

i_first_obs_over_n = self.i_first_obs_over_n_rv()
offset_ref_logit_i_first_obs = self.offset_ref_logit_i_first_obs_rv()

mean_initial_exp_growth_rate = self.mean_initial_exp_growth_rate_rv()
offset_ref_initial_exp_growth_rate = (
self.offset_ref_initial_exp_growth_rate_rv()
)

i_first_obs_over_n_ref_subpop = transforms.SigmoidTransform()(
transforms.logit(i_first_obs_over_n)
+ jnp.where(self.n_subpops > 1, offset_ref_logit_i_first_obs, 0)
)
initial_exp_growth_rate_ref_subpop = mean_initial_exp_growth_rate + jnp.where(
self.n_subpops > 1, offset_ref_initial_exp_growth_rate, 0
)

offset_ref_log_r_t = self.offset_ref_log_r_t_rv()
log_rtu_ref_subpop_in_week = log_rtu_weekly + jnp.where(
self.n_subpops > 1, offset_ref_log_r_t, 0
)

if self.n_subpops == 1:
i_first_obs_over_n_subpop = i_first_obs_over_n_ref_subpop
initial_exp_growth_rate_subpop = initial_exp_growth_rate_ref_subpop
log_rtu_weekly_subpop = log_rtu_ref_subpop_in_week[:, jnp.newaxis]
i_first_obs_over_n_subpop = self.i0_first_obs_n_rv()
initial_exp_growth_rate_subpop = self.initialization_rate_rv()
log_rtu_weekly_subpop = log_rtu_weekly[:, jnp.newaxis]

Check warning on line 108 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L105-L108

Added lines #L105 - L108 were not covered by tests
else:
sigma_i_first_obs = self.sigma_i_first_obs_rv()
i_first_obs_over_n_ref_subpop = transforms.SigmoidTransform()(

Check warning on line 110 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L110

Added line #L110 was not covered by tests
transforms.logit(self.i0_first_obs_n_rv())
+ jnp.where(
self.n_subpops > 1,
self.offset_ref_logit_i_first_obs_rv(),
0,
)
)
initial_exp_growth_rate_ref_subpop = (

Check warning on line 118 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L118

Added line #L118 was not covered by tests
self.initialization_rate_rv()
+ jnp.where(
self.n_subpops > 1,
self.offset_ref_initial_exp_growth_rate_rv(),
0,
)
)
log_rtu_weekly_ref_subpop = log_rtu_weekly + jnp.where(

Check warning on line 126 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L126

Added line #L126 was not covered by tests
self.n_subpops > 1, self.offset_ref_log_rt_rv(), 0
)
i_first_obs_over_n_non_ref_subpop_rv = TransformedVariable(

Check warning on line 129 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L129

Added line #L129 was not covered by tests
"i_first_obs_over_n_non_ref_subpop",
DistributionalVariable(
"i_first_obs_over_n_non_ref_subpop_raw",
dist.Normal(
transforms.logit(i_first_obs_over_n), sigma_i_first_obs
transforms.logit(self.i0_first_obs_n_rv()),
self.sigma_i_first_obs_rv(),
),
reparam=LocScaleReparam(0),
),
transforms=transforms.SigmoidTransform(),
)
sigma_initial_exp_growth_rate = self.sigma_initial_exp_growth_rate_rv()
initial_exp_growth_rate_non_ref_subpop_rv = DistributionalVariable(

Check warning on line 141 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L141

Added line #L141 was not covered by tests
"initial_exp_growth_rate_non_ref_subpop_raw",
dist.Normal(
mean_initial_exp_growth_rate,
sigma_initial_exp_growth_rate,
self.initialization_rate_rv(),
self.sigma_initial_exp_growth_rate_rv(),
),
reparam=LocScaleReparam(0),
)
Expand Down Expand Up @@ -191,13 +188,13 @@ def sample(self, n_days_post_init: int, n_weeks_post_init: int):
noise_sd=sigma_rt,
)

log_rtu_non_ref_subpop_in_week = (
log_rtu_weekly_non_ref_subpop = (

Check warning on line 191 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L191

Added line #L191 was not covered by tests
rtu_subpop_ar_weekly + log_rtu_weekly[:, jnp.newaxis]
)
log_rtu_weekly_subpop = jnp.concat(

Check warning on line 194 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L194

Added line #L194 was not covered by tests
[
log_rtu_ref_subpop_in_week[:, jnp.newaxis],
log_rtu_non_ref_subpop_in_week,
log_rtu_weekly_ref_subpop[:, jnp.newaxis],
log_rtu_weekly_non_ref_subpop,
],
axis=1,
)
Expand All @@ -214,7 +211,9 @@ def sample(self, n_days_post_init: int, n_weeks_post_init: int):
jnp.log(i_first_obs_over_n_subpop)
- self.unobs_time * initial_exp_growth_rate_subpop
)
i0_subpop_rv = DeterministicVariable("i0_subpop", jnp.exp(log_i0_subpop))
i0_subpop_rv = DeterministicVariable(

Check warning on line 214 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L214

Added line #L214 was not covered by tests
"i0_subpop", jnp.exp(log_i0_subpop)
)
initial_exp_growth_rate_subpop_rv = DeterministicVariable(

Check warning on line 217 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L217

Added line #L217 was not covered by tests
"initial_exp_growth_rate_subpop", initial_exp_growth_rate_subpop
)
Expand Down Expand Up @@ -248,15 +247,15 @@ def sample(self, n_days_post_init: int, n_weeks_post_init: int):
if latent_infections_subpop.shape[0] == 1:
latent_infections_subpop = latent_infections_subpop.T

Check warning on line 248 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L247-L248

Added lines #L247 - L248 were not covered by tests

latent_infections_total = jnp.sum(
latent_infections = jnp.sum(

Check warning on line 250 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L250

Added line #L250 was not covered by tests
self.pop_fraction * latent_infections_subpop, axis=1
)

numpyro.deterministic("rtu_subpop", rtu_subpop)

Check warning on line 254 in pyrenew_hew/pyrenew_hew_model.py

View check run for this annotation

Codecov / codecov/patch

pyrenew_hew/pyrenew_hew_model.py#L254

Added line #L254 was not covered by tests
numpyro.deterministic("rt", inf_with_feedback_proc_sample.rt)
numpyro.deterministic("latent_infections_total", latent_infections_total)
numpyro.deterministic("latent_infections", latent_infections)

return latent_infections_total
return latent_infections


class EDVisitObservationProcess(RandomVariable):
Expand Down Expand Up @@ -342,7 +341,10 @@ def sample(
)[-n_observed_disease_ed_visits_datapoints:]

latent_ed_visits_final = (
potential_latent_ed_visits * iedr * ed_wday_effect * population_size
potential_latent_ed_visits
* iedr
* ed_wday_effect
* population_size
)

if right_truncation_offset is not None:
Expand All @@ -359,7 +361,9 @@ def sample(
mode="constant",
constant_values=(1, 0),
)
latent_ed_visits_now = latent_ed_visits_final * prop_already_reported
latent_ed_visits_now = (
latent_ed_visits_final * prop_already_reported
)
else:
latent_ed_visits_now = latent_ed_visits_final

Expand All @@ -383,7 +387,9 @@ def __init__(
hosp_admit_neg_bin_concentration_rv: RandomVariable,
):
self.inf_to_hosp_admit_rv = inf_to_hosp_admit_rv
self.hosp_admit_neg_bin_concentration_rv = hosp_admit_neg_bin_concentration_rv
self.hosp_admit_neg_bin_concentration_rv = (
hosp_admit_neg_bin_concentration_rv
)

def validate(self):
pass
Expand Down Expand Up @@ -513,7 +519,9 @@ def sample(
sampled_ed_visits = self.ed_visit_obs_process_rv(
latent_infections=latent_infections,
population_size=self.population_size,
data_observed_disease_ed_visits=(data_observed_disease_ed_visits),
data_observed_disease_ed_visits=(
data_observed_disease_ed_visits
),
n_observed_disease_ed_visits_datapoints=(
n_observed_disease_ed_visits_datapoints
),
Expand Down Expand Up @@ -606,7 +614,9 @@ def create_pyrenew_hew_model_from_stan_data(stan_data_file):
"inf_feedback",
DistributionalVariable(
"inf_feedback_raw",
dist.LogNormal(inf_feedback_prior_logmean, inf_feedback_prior_logsd),
dist.LogNormal(
inf_feedback_prior_logmean, inf_feedback_prior_logsd
),
),
transforms=transformation.AffineTransform(loc=0, scale=-1),
)
Expand Down Expand Up @@ -638,7 +648,9 @@ def create_pyrenew_hew_model_from_stan_data(stan_data_file):
"ed_wday_effect",
DistributionalVariable(
"ed_wday_effect_raw",
dist.Dirichlet(jnp.array(stan_data["hosp_wday_effect_prior_alpha"])),
dist.Dirichlet(
jnp.array(stan_data["hosp_wday_effect_prior_alpha"])
),
),
transformation.AffineTransform(loc=0, scale=7),
)
Expand Down Expand Up @@ -713,7 +725,9 @@ def create_pyrenew_hew_model_from_stan_data(stan_data_file):

my_hosp_admit_obs_model = HospAdmitObservationProcess(
inf_to_hosp_admit_rv=inf_to_hosp_admit_rv,
hosp_admit_neg_bin_concentration_rv=(hosp_admit_neg_bin_concentration_rv),
hosp_admit_neg_bin_concentration_rv=(
hosp_admit_neg_bin_concentration_rv
),
)

my_wastewater_obs_model = WastewaterObservationProcess()
Expand Down

0 comments on commit 50dc084

Please sign in to comment.