Skip to content

Commit

Permalink
Make which quantities are sampled more readily configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris committed Feb 5, 2025
1 parent 0e84d49 commit 6e1c5b6
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 9 deletions.
21 changes: 16 additions & 5 deletions pipelines/build_pyrenew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

def build_model_from_dir(
model_dir: Path,
sample_ed_visits: bool = False,
sample_hospital_admission: bool = False,
sample_wastewater: bool = False,
) -> tuple[PyrenewHEWModel, PyrenewHEWData]:
data_path = Path(model_dir) / "data" / "data_for_model_fit.json"
prior_path = Path(model_dir) / "priors.py"
Expand Down Expand Up @@ -47,12 +50,20 @@ def build_model_from_dir(
jnp.array(model_data["generation_interval_pmf"]),
) # check if off by 1 or reversed

data_observed_disease_ed_visits = jnp.array(
model_data["data_observed_disease_ed_visits"]
data_observed_disease_ed_visits = (

Check warning on line 53 in pipelines/build_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/build_pyrenew_model.py#L53

Added line #L53 was not covered by tests
jnp.array(model_data["data_observed_disease_ed_visits"])
if sample_ed_visits
else None
)
data_observed_disease_hospital_admissions = jnp.array(
model_data["data_observed_disease_hospital_admissions"]
data_observed_disease_hospital_admissions = (

Check warning on line 58 in pipelines/build_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/build_pyrenew_model.py#L58

Added line #L58 was not covered by tests
jnp.array(model_data["data_observed_disease_hospital_admissions"])
if sample_hospital_admissions
else None
)

# placeholder
data_observed_disease_wastewater = None if sample_wastewater else None

Check warning on line 65 in pipelines/build_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/build_pyrenew_model.py#L65

Added line #L65 was not covered by tests

population_size = jnp.array(model_data["state_pop"])

ed_right_truncation_pmf_rv = DeterministicVariable(
Expand Down Expand Up @@ -133,7 +144,7 @@ def build_model_from_dir(
data_observed_disease_hospital_admissions=(
data_observed_disease_hospital_admissions
),
data_observed_disease_wastewater=None, # placeholder
data_observed_disease_wastewater=data_observed_disease_wastewater,
right_truncation_offset=right_truncation_offset,
first_ed_visits_date=first_ed_visits_date,
first_hospital_admissions_date=first_hospital_admissions_date,
Expand Down
16 changes: 12 additions & 4 deletions pipelines/fit_pyrenew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
def fit_and_save_model(
model_run_dir: str,
model_name: str,
sample_ed_visits: bool = False,
sample_hospital_admissions: bool = False,
sample_wastewater: bool = False,
n_warmup: int = 1000,
n_samples: int = 1000,
n_chains: int = 4,
Expand All @@ -26,12 +29,17 @@ def fit_and_save_model(
"rng_key must be an integer with which "
"to seed :func:`jax.random.key`"
)
(my_model, my_data) = build_model_from_dir(model_run_dir)
(my_model, my_data) = build_model_from_dir(

Check warning on line 32 in pipelines/fit_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/fit_pyrenew_model.py#L32

Added line #L32 was not covered by tests
model_run_dir,
sample_ed_visits=sample_ed_visits,
sample_hospital_admissions=sample_hospital_admissions,
sample_wastewater=sample_wastewater,
)
my_model.run(
data=my_data,
sample_ed_visits=True,
sample_hospital_admissions=True,
sample_wastewater=False,
sample_ed_visits=sample_ed_visits,
sample_hospital_admissions=sample_hospital_admissions,
sample_wastewater=sample_wastewater,
num_warmup=n_warmup,
num_samples=n_samples,
rng_key=rng_key,
Expand Down
1 change: 1 addition & 0 deletions pipelines/forecast_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ def main(
n_warmup=n_warmup,
n_samples=n_samples,
n_chains=n_chains,
sample_ed_visits=True,
)
logger.info("Model fitting complete")

Expand Down

0 comments on commit 6e1c5b6

Please sign in to comment.