Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modularize model in preparation for fitting real NHSN data #277

Merged
merged 10 commits into from
Jan 7, 2025
46 changes: 38 additions & 8 deletions pipelines/build_pyrenew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
import jax.numpy as jnp
from pyrenew.deterministic import DeterministicVariable

from pyrenew_hew.pyrenew_hew_model import pyrenew_hew_model
from pyrenew_hew.pyrenew_hew_model import (

Check warning on line 7 in pipelines/build_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/build_pyrenew_model.py#L7

Added line #L7 was not covered by tests
EDVisitObservationProcess,
HospAdmitObservationProcess,
LatentInfectionProcess,
PyrenewHEWModel,
WastewaterObservationProcess,
)


def build_model_from_dir(model_dir):
Expand All @@ -21,6 +27,11 @@
"inf_to_ed", jnp.array(model_data["inf_to_ed_pmf"])
) # check if off by 1 or reversed

# use same as inf to ed, per NNH guidelines
inf_to_hosp_admit_rv = DeterministicVariable(

Check warning on line 31 in pipelines/build_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/build_pyrenew_model.py#L31

Added line #L31 was not covered by tests
"inf_to_hosp_admit", jnp.array(model_data["inf_to_ed_pmf"])
) # check if off by 1 or reversed

generation_interval_pmf_rv = DeterministicVariable(
"generation_interval_pmf",
jnp.array(model_data["generation_interval_pmf"]),
Expand All @@ -37,9 +48,9 @@
data_observed_disease_hospital_admissions = jnp.array(
model_data["data_observed_disease_hospital_admissions"]
)
state_pop = jnp.array(model_data["state_pop"])
population_size = jnp.array(model_data["state_pop"])

Check warning on line 51 in pipelines/build_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/build_pyrenew_model.py#L51

Added line #L51 was not covered by tests

right_truncation_pmf_rv = DeterministicVariable(
ed_right_truncation_pmf_rv = DeterministicVariable(

Check warning on line 53 in pipelines/build_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/build_pyrenew_model.py#L53

Added line #L53 was not covered by tests
"right_truncation_pmf", jnp.array(model_data["right_truncation_pmf"])
)

Expand All @@ -55,8 +66,7 @@

right_truncation_offset = model_data["right_truncation_offset"]

my_model = pyrenew_hew_model(
state_pop=state_pop,
my_latent_infection_model = LatentInfectionProcess(

Check warning on line 69 in pipelines/build_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/build_pyrenew_model.py#L69

Added line #L69 was not covered by tests
i0_first_obs_n_rv=priors["i0_first_obs_n_rv"],
initialization_rate_rv=priors["initialization_rate_rv"],
log_r_mu_intercept_rv=priors["log_r_mu_intercept_rv"],
Expand All @@ -65,14 +75,34 @@
generation_interval_pmf_rv=generation_interval_pmf_rv,
infection_feedback_strength_rv=priors["inf_feedback_strength_rv"],
infection_feedback_pmf_rv=infection_feedback_pmf_rv,
n_initialization_points=uot,
)

my_ed_visit_obs_model = EDVisitObservationProcess(

Check warning on line 81 in pipelines/build_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/build_pyrenew_model.py#L81

Added line #L81 was not covered by tests
p_ed_mean_rv=priors["p_ed_visit_mean_rv"],
p_ed_w_sd_rv=priors["p_ed_visit_w_sd_rv"],
autoreg_p_ed_rv=priors["autoreg_p_ed_visit_rv"],
ed_wday_effect_rv=priors["ed_visit_wday_effect_rv"],
inf_to_ed_rv=inf_to_ed_rv,
phi_rv=priors["phi_rv"],
right_truncation_pmf_rv=right_truncation_pmf_rv,
n_initialization_points=uot,
ed_neg_bin_concentration_rv=(priors["ed_neg_bin_concentration_rv"]),
ed_right_truncation_pmf_rv=ed_right_truncation_pmf_rv,
)

my_hosp_admit_obs_model = HospAdmitObservationProcess(

Check warning on line 91 in pipelines/build_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/build_pyrenew_model.py#L91

Added line #L91 was not covered by tests
inf_to_hosp_admit_rv=inf_to_hosp_admit_rv,
hosp_admit_neg_bin_concentration_rv=(
priors["hosp_admit_neg_bin_concentration_rv"]
),
)

my_wastewater_obs_model = WastewaterObservationProcess()

Check warning on line 98 in pipelines/build_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/build_pyrenew_model.py#L98

Added line #L98 was not covered by tests

my_model = PyrenewHEWModel(

Check warning on line 100 in pipelines/build_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/build_pyrenew_model.py#L100

Added line #L100 was not covered by tests
population_size=population_size,
latent_infection_process_rv=my_latent_infection_model,
ed_visit_obs_process_rv=my_ed_visit_obs_model,
hosp_admit_obs_process_rv=my_hosp_admit_obs_model,
wastewater_obs_process_rv=my_wastewater_obs_model,
)

return (
Expand Down
9 changes: 8 additions & 1 deletion pipelines/priors/prod_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,11 @@
)

# Based on looking at some historical posteriors.
phi_rv = DistributionalVariable("phi", dist.LogNormal(4, 1))
ed_neg_bin_concentration_rv = DistributionalVariable(
"ed_visit_neg_bin_concentration", dist.LogNormal(4, 1)
)

# more diffuse than ED visit, same mean
hosp_admit_neg_bin_concentration_rv = DistributionalVariable(
"hosp_admit_neg_bin_concentration", dist.LogNormal(4, 2)
)
Loading
Loading