Skip to content

Commit

Permalink
Add switches to fitting scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris committed Feb 6, 2025
1 parent 557330d commit 8baa0f0
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 3 deletions.
24 changes: 24 additions & 0 deletions pipelines/batch/setup_prod_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def main(
pool_id: str,
diseases: str | list[str],
output_subdir: str | Path = "./",
sample_ed_visits: bool = False,
sample_hospital_admissions: bool = False,
sample_wastewater: bool = False,
container_image_name: str = "pyrenew-hew",
container_image_version: str = "latest",
n_training_days: int = 90,
Expand Down Expand Up @@ -253,6 +256,27 @@ def main(
default="latest",
)


parser.add_argument(
"--sample-ed-visits",
type=bool,
action=argparse.BooleanOptionalAction,
help="If provided, fit to and predict ED visit data.",
)
parser.add_argument(
"--sample-hospital-admissions",
type=bool,
action=argparse.BooleanOptionalAction,
help=("If provided, fit to and predict hospital admissions data."),
)
parser.add_argument(
"--sample-wastewater",
type=bool,
action=argparse.BooleanOptionalAction,
help="If provided, fit to and predict wastewater data.",
)


parser.add_argument(
"--n-training-days",
type=int,
Expand Down
20 changes: 20 additions & 0 deletions pipelines/fit_pyrenew_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,26 @@ def fit_and_save_model(
required=True,
help="Name of the model to use for generating predictions.",
)

parser.add_argument(

Check warning on line 79 in pipelines/fit_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/fit_pyrenew_model.py#L79

Added line #L79 was not covered by tests
"--sample-ed-visits",
type=bool,
action=argparse.BooleanOptionalAction,
help="If provided, fit to ED visit data.",
)
parser.add_argument(

Check warning on line 85 in pipelines/fit_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/fit_pyrenew_model.py#L85

Added line #L85 was not covered by tests
"--sample-hospital-admissions",
type=bool,
action=argparse.BooleanOptionalAction,
help=("If provided, fit to hospital admissions data."),
)
parser.add_argument(

Check warning on line 91 in pipelines/fit_pyrenew_model.py

View check run for this annotation

Codecov / codecov/patch

pipelines/fit_pyrenew_model.py#L91

Added line #L91 was not covered by tests
"--sample-wastewater",
type=bool,
action=argparse.BooleanOptionalAction,
help="If provided, fit to wastewater data.",
)

parser.add_argument(
"--n-warmup",
type=int,
Expand Down
33 changes: 30 additions & 3 deletions pipelines/forecast_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,10 @@ def main(
exclude_last_n_days: int = 0,
score: bool = False,
eval_data_path: Path = None,
):
sample_ed_visits: bool = False,
sample_hospital_admissions: bool = False,
sample_wastewater: bool = False,
) -> None:
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -337,7 +340,9 @@ def main(
n_warmup=n_warmup,
n_samples=n_samples,
n_chains=n_chains,
sample_ed_visits=True,
sample_ed_visits=sample_ed_visits,
sample_hospital_admissions=sample_hospital_admissions,
sample_wastewater=sample_wastewater,
)
logger.info("Model fitting complete")

Expand All @@ -348,7 +353,9 @@ def main(
model_run_dir,
"pyrenew_e",
n_days_past_last_training,
predict_ed_visits=True,
predict_ed_visits=sample_ed_visits,
predict_hospital_admissions=sample_hospital_admissions,
predict_wastewater=sample_wastewater,
)

logger.info(
Expand Down Expand Up @@ -528,6 +535,26 @@ def main(
type=Path,
help=("Path to a parquet file containing compehensive truth data."),
)

parser.add_argument(

Check warning on line 539 in pipelines/forecast_state.py

View check run for this annotation

Codecov / codecov/patch

pipelines/forecast_state.py#L539

Added line #L539 was not covered by tests
"--sample-ed-visits",
type=bool,
action=argparse.BooleanOptionalAction,
help="If provided, fit to ED visit data.",
)
parser.add_argument(

Check warning on line 545 in pipelines/forecast_state.py

View check run for this annotation

Codecov / codecov/patch

pipelines/forecast_state.py#L545

Added line #L545 was not covered by tests
"--sample-hospital-admissions",
type=bool,
action=argparse.BooleanOptionalAction,
help=("If provided, fit to hospital admissions data."),
)
parser.add_argument(

Check warning on line 551 in pipelines/forecast_state.py

View check run for this annotation

Codecov / codecov/patch

pipelines/forecast_state.py#L551

Added line #L551 was not covered by tests
"--sample-wastewater",
type=bool,
action=argparse.BooleanOptionalAction,
help="If provided, fit to wastewater data.",
)

args = parser.parse_args()
numpyro.set_host_device_count(args.n_chains)
main(**vars(args))

0 comments on commit 8baa0f0

Please sign in to comment.