Skip to content

Commit

Permalink
Modularize model in preparation for fitting real NHSN data (#277)
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris authored Jan 7, 2025
1 parent 1fb4c2e commit 5d790b7
Show file tree
Hide file tree
Showing 3 changed files with 332 additions and 105 deletions.
46 changes: 38 additions & 8 deletions pipelines/build_pyrenew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
import jax.numpy as jnp
from pyrenew.deterministic import DeterministicVariable

from pyrenew_hew.pyrenew_hew_model import pyrenew_hew_model
from pyrenew_hew.pyrenew_hew_model import (
EDVisitObservationProcess,
HospAdmitObservationProcess,
LatentInfectionProcess,
PyrenewHEWModel,
WastewaterObservationProcess,
)


def build_model_from_dir(model_dir):
Expand All @@ -21,6 +27,11 @@ def build_model_from_dir(model_dir):
"inf_to_ed", jnp.array(model_data["inf_to_ed_pmf"])
) # check if off by 1 or reversed

# use same as inf to ed, per NNH guidelines
inf_to_hosp_admit_rv = DeterministicVariable(
"inf_to_hosp_admit", jnp.array(model_data["inf_to_ed_pmf"])
) # check if off by 1 or reversed

generation_interval_pmf_rv = DeterministicVariable(
"generation_interval_pmf",
jnp.array(model_data["generation_interval_pmf"]),
Expand All @@ -37,9 +48,9 @@ def build_model_from_dir(model_dir):
data_observed_disease_hospital_admissions = jnp.array(
model_data["data_observed_disease_hospital_admissions"]
)
state_pop = jnp.array(model_data["state_pop"])
population_size = jnp.array(model_data["state_pop"])

right_truncation_pmf_rv = DeterministicVariable(
ed_right_truncation_pmf_rv = DeterministicVariable(
"right_truncation_pmf", jnp.array(model_data["right_truncation_pmf"])
)

Expand All @@ -55,8 +66,7 @@ def build_model_from_dir(model_dir):

right_truncation_offset = model_data["right_truncation_offset"]

my_model = pyrenew_hew_model(
state_pop=state_pop,
my_latent_infection_model = LatentInfectionProcess(
i0_first_obs_n_rv=priors["i0_first_obs_n_rv"],
initialization_rate_rv=priors["initialization_rate_rv"],
log_r_mu_intercept_rv=priors["log_r_mu_intercept_rv"],
Expand All @@ -65,14 +75,34 @@ def build_model_from_dir(model_dir):
generation_interval_pmf_rv=generation_interval_pmf_rv,
infection_feedback_strength_rv=priors["inf_feedback_strength_rv"],
infection_feedback_pmf_rv=infection_feedback_pmf_rv,
n_initialization_points=uot,
)

my_ed_visit_obs_model = EDVisitObservationProcess(
p_ed_mean_rv=priors["p_ed_visit_mean_rv"],
p_ed_w_sd_rv=priors["p_ed_visit_w_sd_rv"],
autoreg_p_ed_rv=priors["autoreg_p_ed_visit_rv"],
ed_wday_effect_rv=priors["ed_visit_wday_effect_rv"],
inf_to_ed_rv=inf_to_ed_rv,
phi_rv=priors["phi_rv"],
right_truncation_pmf_rv=right_truncation_pmf_rv,
n_initialization_points=uot,
ed_neg_bin_concentration_rv=(priors["ed_neg_bin_concentration_rv"]),
ed_right_truncation_pmf_rv=ed_right_truncation_pmf_rv,
)

my_hosp_admit_obs_model = HospAdmitObservationProcess(
inf_to_hosp_admit_rv=inf_to_hosp_admit_rv,
hosp_admit_neg_bin_concentration_rv=(
priors["hosp_admit_neg_bin_concentration_rv"]
),
)

my_wastewater_obs_model = WastewaterObservationProcess()

my_model = PyrenewHEWModel(
population_size=population_size,
latent_infection_process_rv=my_latent_infection_model,
ed_visit_obs_process_rv=my_ed_visit_obs_model,
hosp_admit_obs_process_rv=my_hosp_admit_obs_model,
wastewater_obs_process_rv=my_wastewater_obs_model,
)

return (
Expand Down
9 changes: 8 additions & 1 deletion pipelines/priors/prod_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,11 @@
)

# Based on looking at some historical posteriors.
phi_rv = DistributionalVariable("phi", dist.LogNormal(4, 1))
ed_neg_bin_concentration_rv = DistributionalVariable(
"ed_visit_neg_bin_concentration", dist.LogNormal(4, 1)
)

# more diffuse than ED visit, same mean
hosp_admit_neg_bin_concentration_rv = DistributionalVariable(
"hosp_admit_neg_bin_concentration", dist.LogNormal(4, 2)
)
Loading

0 comments on commit 5d790b7

Please sign in to comment.