Skip to content

Commit

Permalink
Make which quantities are sampled more readily configurable (#326)
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris authored Feb 13, 2025
1 parent 5bd8178 commit 026f1eb
Show file tree
Hide file tree
Showing 9 changed files with 679 additions and 121 deletions.
299 changes: 199 additions & 100 deletions pipelines/batch/setup_prod_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import argparse
import itertools
import re
from pathlib import Path

import polars as pl
from azure.batch import models
from azuretools.auth import EnvCredentialHandler
from azuretools.client import get_batch_service_client
Expand All @@ -21,6 +21,12 @@ def main(
pool_id: str,
diseases: str | list[str],
output_subdir: str | Path = "./",
fit_ed_visits: bool = False,
fit_hospital_admissions: bool = False,
fit_wastewater: bool = False,
forecast_ed_visits: bool = False,
forecast_hospital_admissions: bool = False,
forecast_wastewater: bool = False,
container_image_name: str = "pyrenew-hew",
container_image_version: str = "latest",
n_training_days: int = 90,
Expand Down Expand Up @@ -54,7 +60,25 @@ def main(
Subdirectory of the output blob storage container
in which to save results.
container_image_name:
fit_ed_visits
Fit to ED visits data? Default ``False``.
fit_hospital_admissions
Fit to hospital admissions data? Default ``False``.
fit_wastewater
Fit to wastewater data? Default ``False``.
forecast_ed_visits
Forecast ED visits? Default ``False``.
forecast_hospital_admissions
Forecast hospital admissions? Default ``False``.
forecast_wastewater
Forecast wastewater concentrations? Default ``False``.
container_image_name
Name of the container to use for the job.
This container should exist within the Azure
Container Registry account associated to
Expand Down Expand Up @@ -109,6 +133,24 @@ def main(
f"supported diseases are: {', '.join(supported_diseases)}"
)

signals = ["ed_visits", "hospital_admissions", "wastewater"]

for signal in signals:
fit = locals().get(f"fit_{signal}", False)
forecast = locals().get(f"forecast_{signal}", False)
if fit and not forecast:
raise ValueError(
"This pipeline does not currently support "
"fitting to but not forecasting a signal. "
f"Asked to fit but not forecast {signal}."
)
any_fit = any([locals().get(f"fit_{signal}", False) for signal in signals])
if not any_fit:
raise ValueError(
"pyrenew_null (fitting to no signals) "
"is not supported by this pipeline"
)

pyrenew_hew_output_container = (
"pyrenew-test-output" if test else "pyrenew-hew-prod-output"
)
Expand Down Expand Up @@ -155,14 +197,31 @@ def main(
],
)

needed_hew_flags = [
"fit_ed_visits",
"fit_hospital_admissions",
"fit_wastewater",
"forecast_ed_visits",
"forecast_hospital_admissions",
"forecast_wastewater",
]

def as_flag(flag_name, bool_val):
prefix = "" if bool_val else "no-"
return f"--{prefix}{re.sub("_", "-", flag_name)}"

hew_flags = " ".join(
[as_flag(k, v) for k, v in locals().items() if k in needed_hew_flags]
)

base_call = (
"/bin/bash -c '"
"python pipelines/forecast_state.py "
"--disease {disease} "
"--state {state} "
"--n-training-days {n_training_days} "
"--n-warmup {n_warmup} "
"--n-samples {n_samples} "
f"--n-training-days {n_training_days} "
f"--n-warmup {n_warmup} "
f"--n-samples {n_samples} "
"--facility-level-nssp-data-dir nssp-etl/gold "
"--state-level-nssp-data-dir "
"nssp-archival-vintages/gold "
Expand All @@ -171,8 +230,9 @@ def main(
"--priors-path pipelines/priors/prod_priors.py "
"--credentials-path config/creds.toml "
"--report-date {report_date} "
"--exclude-last-n-days {exclude_last_n_days} "
f"--exclude-last-n-days {exclude_last_n_days} "
"--no-score "
f"{hew_flags} "
"--eval-data-path "
"nssp-etl/latest_comprehensive.parquet"
"'"
Expand All @@ -197,10 +257,6 @@ def main(
state=state,
disease=disease,
report_date="latest",
n_warmup=n_warmup,
n_samples=n_samples,
n_training_days=n_training_days,
exclude_last_n_days=exclude_last_n_days,
output_dir=str(Path("output", output_subdir)),
),
container_settings=container_settings,
Expand All @@ -210,98 +266,141 @@ def main(
return None


parser = argparse.ArgumentParser()

parser.add_argument("job_id", type=str, help="Name for the Azure batch job")
parser.add_argument(
"pool_id",
type=str,
help=("Name of the Azure batch pool on which to run the job"),
)
parser.add_argument(
"--diseases",
type=str,
default="COVID-19 Influenza",
help=(
"Name(s) of disease(s) to run as part of the job, "
"as a whitespace-separated string. Supported "
"values are 'COVID-19' and 'Influenza'. "
"Default 'COVID-19 Influenza' (i.e. run for both)."
),
)

parser.add_argument(
"--output-subdir",
type=str,
help=(
"Subdirectory of the output blob storage container "
"in which to save results."
),
default="./",
)

parser.add_argument(
"--container-image-name",
type=str,
help="Name of the container to use for the job.",
default="pyrenew-hew",
)

parser.add_argument(
"--container-image-version",
type=str,
help="Version of the container to use for the job.",
default="latest",
)

parser.add_argument(
"--n-training-days",
type=int,
help=(
"Number of 'training days' of observed data "
"to use for model fitting."
),
default=90,
)

parser.add_argument(
"--exclude-last-n-days",
type=int,
help=(
"Number of days to drop from the end of the timeseries "
"of observed data when constructing the training data."
),
default=1,
)

parser.add_argument(
"--locations-include",
type=str,
help=(
"Two-letter USPS location abbreviations to "
"include in the job, as a whitespace-separated "
"string. If not set, include all ",
"available locations except any explicitly excluded "
"via --locations-exclude.",
),
default=None,
)

parser.add_argument(
"--locations-exclude",
type=str,
help=(
"Two-letter USPS location abbreviations to "
"exclude from the job, as a whitespace-separated "
"string. Defaults to a set of locations for which "
"we typically do not have available NSSP ED visit "
"data: 'AS GU MO MP PR UM VI WY'."
),
default="AS GU MO MP PR UM VI WY",
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument(
"job_id", type=str, help="Name for the Azure batch job"
)
parser.add_argument(
"pool_id",
type=str,
help=("Name of the Azure batch pool on which to run the job"),
)
parser.add_argument(
"--diseases",
type=str,
default="COVID-19 Influenza",
help=(
"Name(s) of disease(s) to run as part of the job, "
"as a whitespace-separated string. Supported "
"values are 'COVID-19' and 'Influenza'. "
"Default 'COVID-19 Influenza' (i.e. run for both)."
),
)

parser.add_argument(
"--output-subdir",
type=str,
help=(
"Subdirectory of the output blob storage container "
"in which to save results."
),
default="./",
)

parser.add_argument(
"--container-image-name",
type=str,
help="Name of the container to use for the job.",
default="pyrenew-hew",
)

parser.add_argument(
"--container-image-version",
type=str,
help="Version of the container to use for the job.",
default="latest",
)

parser.add_argument(
"--fit-ed-visits",
type=bool,
action=argparse.BooleanOptionalAction,
help="If provided, fit to ED visit data.",
)

parser.add_argument(
"--fit-hospital-admissions",
type=bool,
action=argparse.BooleanOptionalAction,
help=("If provided, fit to hospital admissions data."),
)

parser.add_argument(
"--fit-wastewater",
type=bool,
action=argparse.BooleanOptionalAction,
help="If provided, fit to wastewater data.",
)

parser.add_argument(
"--forecast-ed-visits",
type=bool,
action=argparse.BooleanOptionalAction,
help="If provided, forecast ED visits.",
)

parser.add_argument(
"--forecast-hospital-admissions",
type=bool,
action=argparse.BooleanOptionalAction,
help=("If provided, forecast hospital admissions."),
)

parser.add_argument(
"--forecast-wastewater",
type=bool,
action=argparse.BooleanOptionalAction,
help="If provided, forecast wastewater concentrations.",
)

parser.add_argument(
"--n-training-days",
type=int,
help=(
"Number of 'training days' of observed data "
"to use for model fitting."
),
default=90,
)

parser.add_argument(
"--exclude-last-n-days",
type=int,
help=(
"Number of days to drop from the end of the timeseries "
"of observed data when constructing the training data."
),
default=1,
)

parser.add_argument(
"--locations-include",
type=str,
help=(
"Two-letter USPS location abbreviations to "
"include in the job, as a whitespace-separated "
"string. If not set, include all ",
"available locations except any explicitly excluded "
"via --locations-exclude.",
),
default=None,
)

parser.add_argument(
"--locations-exclude",
type=str,
help=(
"Two-letter USPS location abbreviations to "
"exclude from the job, as a whitespace-separated "
"string. Defaults to a set of locations for which "
"we typically do not have available NSSP ED visit "
"data: 'AS GU MO MP PR UM VI WY'."
),
default="AS GU MO MP PR UM VI WY",
)

if __name__ == "__main__":
args = parser.parse_args()
args.diseases = args.diseases.split()
if args.locations_include is not None:
Expand Down
Loading

0 comments on commit 026f1eb

Please sign in to comment.