From 6e1c5b653b391e529ff5856254828868466bd415 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Wed, 5 Feb 2025 16:47:49 -0500 Subject: [PATCH] Make which quantities are sampled more readily configurable --- pipelines/build_pyrenew_model.py | 21 ++++++++++++++++----- pipelines/fit_pyrenew_model.py | 16 ++++++++++++---- pipelines/forecast_state.py | 1 + 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index 0019e29f..8d35071b 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -18,6 +18,9 @@ def build_model_from_dir( model_dir: Path, + sample_ed_visits: bool = False, + sample_hospital_admission: bool = False, + sample_wastewater: bool = False, ) -> tuple[PyrenewHEWModel, PyrenewHEWData]: data_path = Path(model_dir) / "data" / "data_for_model_fit.json" prior_path = Path(model_dir) / "priors.py" @@ -47,12 +50,20 @@ def build_model_from_dir( jnp.array(model_data["generation_interval_pmf"]), ) # check if off by 1 or reversed - data_observed_disease_ed_visits = jnp.array( - model_data["data_observed_disease_ed_visits"] + data_observed_disease_ed_visits = ( + jnp.array(model_data["data_observed_disease_ed_visits"]) + if sample_ed_visits + else None ) - data_observed_disease_hospital_admissions = jnp.array( - model_data["data_observed_disease_hospital_admissions"] + data_observed_disease_hospital_admissions = ( + jnp.array(model_data["data_observed_disease_hospital_admissions"]) + if sample_hospital_admissions + else None ) + + # placeholder + data_observed_disease_wastewater = None if sample_wastewater else None + population_size = jnp.array(model_data["state_pop"]) ed_right_truncation_pmf_rv = DeterministicVariable( @@ -133,7 +144,7 @@ def build_model_from_dir( data_observed_disease_hospital_admissions=( data_observed_disease_hospital_admissions ), - data_observed_disease_wastewater=None, # placeholder + data_observed_disease_wastewater=data_observed_disease_wastewater, right_truncation_offset=right_truncation_offset, first_ed_visits_date=first_ed_visits_date, first_hospital_admissions_date=first_hospital_admissions_date, diff --git a/pipelines/fit_pyrenew_model.py b/pipelines/fit_pyrenew_model.py index 33fb86b9..5c06a107 100644 --- a/pipelines/fit_pyrenew_model.py +++ b/pipelines/fit_pyrenew_model.py @@ -12,6 +12,9 @@ def fit_and_save_model( model_run_dir: str, model_name: str, + sample_ed_visits: bool = False, + sample_hospital_admissions: bool = False, + sample_wastewater: bool = False, n_warmup: int = 1000, n_samples: int = 1000, n_chains: int = 4, @@ -26,12 +29,17 @@ def fit_and_save_model( "rng_key must be an integer with which " "to seed :func:`jax.random.key`" ) - (my_model, my_data) = build_model_from_dir(model_run_dir) + (my_model, my_data) = build_model_from_dir( + model_run_dir, + sample_ed_visits=sample_ed_visits, + sample_hospital_admissions=sample_hospital_admissions, + sample_wastewater=sample_wastewater, + ) my_model.run( data=my_data, - sample_ed_visits=True, - sample_hospital_admissions=True, - sample_wastewater=False, + sample_ed_visits=sample_ed_visits, + sample_hospital_admissions=sample_hospital_admissions, + sample_wastewater=sample_wastewater, num_warmup=n_warmup, num_samples=n_samples, rng_key=rng_key, diff --git a/pipelines/forecast_state.py b/pipelines/forecast_state.py index a62efada..50cc8d30 100644 --- a/pipelines/forecast_state.py +++ b/pipelines/forecast_state.py @@ -337,6 +337,7 @@ def main( n_warmup=n_warmup, n_samples=n_samples, n_chains=n_chains, + sample_ed_visits=True, ) logger.info("Model fitting complete")