Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make which quantities are sampled more readily configurable #326

Merged
merged 37 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
6e1c5b6
Make which quantities are sampled more readily configurable
dylanhmorris Feb 5, 2025
c3b7858
Fix variable name typo
dylanhmorris Feb 5, 2025
3d15528
Predictive flags and CLI
dylanhmorris Feb 6, 2025
0238435
Set flag in forecast state
dylanhmorris Feb 6, 2025
250ea74
Harmonize data generation
dylanhmorris Feb 6, 2025
557330d
Remove print call
dylanhmorris Feb 6, 2025
8baa0f0
Add switches to fitting scripts
dylanhmorris Feb 6, 2025
c40eabb
Add flags to prod job call
dylanhmorris Feb 6, 2025
674af5d
Fix flag ordering and logic
dylanhmorris Feb 6, 2025
18df0b5
Merge branch 'main' into dhm-configure-which-fit
dylanhmorris Feb 10, 2025
1ef2244
Merge branch 'main' into dhm-configure-which-fit
dylanhmorris Feb 10, 2025
098d568
More meaningful flag names
dylanhmorris Feb 10, 2025
85f692a
DRY flag setting
dylanhmorris Feb 10, 2025
31b6f1d
Align CLI with main function
dylanhmorris Feb 10, 2025
0c75437
Fix missing import
dylanhmorris Feb 10, 2025
ca6e01e
Update end-to-end test
dylanhmorris Feb 10, 2025
6252dfc
Fix typo
dylanhmorris Feb 10, 2025
54771ed
Add build model switch test
dylanhmorris Feb 10, 2025
db3564d
Merge branch 'main' into dhm-configure-which-fit
dylanhmorris Feb 11, 2025
25eaca0
Add programmatic model naming
dylanhmorris Feb 11, 2025
d46b958
use lowercase, add unit test
dylanhmorris Feb 11, 2025
536ed3f
End to end test is only pyrenew_e for now
dylanhmorris Feb 11, 2025
c868636
Do not delete test check on additional data
dylanhmorris Feb 11, 2025
2362f42
Remove extra line
dylanhmorris Feb 11, 2025
3327a94
fix typo in end to end test
dylanhmorris Feb 11, 2025
b8dc990
Move argument parsing inside if __name__==__main__ in setup prod job,…
dylanhmorris Feb 11, 2025
e25aac3
Merge branch 'main' into dhm-configure-which-fit
dylanhmorris Feb 11, 2025
26102b7
Qualify namespace in generate_predictive.py
dylanhmorris Feb 11, 2025
f0eef15
Revert import formulation
dylanhmorris Feb 11, 2025
283d5b0
Fix missing space
dylanhmorris Feb 11, 2025
2b3b881
Apply suggestions from code review
dylanhmorris Feb 11, 2025
0b4c4d7
Update pyrenew_hew/util.py
damonbayer Feb 11, 2025
85d1446
Add hew model iterator
dylanhmorris Feb 12, 2025
3661e5e
Merge branch 'main' into dhm-configure-which-fit
dylanhmorris Feb 13, 2025
474d283
Add temporary checks for disallowed models
dylanhmorris Feb 13, 2025
94ebe3b
Fix typos
dylanhmorris Feb 13, 2025
58dfda3
Merge branch 'main' into dhm-configure-which-fit
dylanhmorris Feb 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
299 changes: 199 additions & 100 deletions pipelines/batch/setup_prod_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import argparse
import itertools
import re
from pathlib import Path

import polars as pl
from azure.batch import models
from azuretools.auth import EnvCredentialHandler
from azuretools.client import get_batch_service_client
Expand All @@ -21,6 +21,12 @@ def main(
pool_id: str,
diseases: str | list[str],
output_subdir: str | Path = "./",
fit_ed_visits: bool = False,
fit_hospital_admissions: bool = False,
fit_wastewater: bool = False,
forecast_ed_visits: bool = False,
forecast_hospital_admissions: bool = False,
forecast_wastewater: bool = False,
container_image_name: str = "pyrenew-hew",
container_image_version: str = "latest",
n_training_days: int = 90,
Expand Down Expand Up @@ -54,7 +60,25 @@ def main(
Subdirectory of the output blob storage container
in which to save results.

container_image_name:
fit_ed_visits
Fit to ED visits data? Default ``False``.

fit_hospital_admissions
Fit to hospital admissions data? Default ``False``.

fit_wastewater
Fit to wastewater data? Default ``False``.

forecast_ed_visits
Forecast ED visits? Default ``False``.

forecast_hospital_admissions
Forecast hospital admissions? Default ``False``.

forecast_wastewater
Forecast wastewater concentrations? Default ``False``.

container_image_name
Name of the container to use for the job.
This container should exist within the Azure
Container Registry account associated to
Expand Down Expand Up @@ -109,6 +133,24 @@ def main(
f"supported diseases are: {', '.join(supported_diseases)}"
)

signals = ["ed_visits", "hospital_admissions", "wastewater"]

for signal in signals:
fit = locals().get(f"fit_{signal}", False)
forecast = locals().get(f"forecast_{signal}", False)
if fit and not forecast:
raise ValueError(
"This pipeline does not currently support "
"fitting to but not forecasting a signal. "
f"Asked to fit but not forecast {signal}."
)
any_fit = any([locals().get(f"fit_{signal}", False) for signal in signals])
if not any_fit:
raise ValueError(
"pyrenew_null (fitting to no signals) "
"is not supported by this pipeline"
)

pyrenew_hew_output_container = (
"pyrenew-test-output" if test else "pyrenew-hew-prod-output"
)
Expand Down Expand Up @@ -155,14 +197,31 @@ def main(
],
)

needed_hew_flags = [
"fit_ed_visits",
"fit_hospital_admissions",
"fit_wastewater",
"forecast_ed_visits",
"forecast_hospital_admissions",
"forecast_wastewater",
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved
]

def as_flag(flag_name, bool_val):
prefix = "" if bool_val else "no-"
return f"--{prefix}{re.sub("_", "-", flag_name)}"

hew_flags = " ".join(
[as_flag(k, v) for k, v in locals().items() if k in needed_hew_flags]
)

base_call = (
"/bin/bash -c '"
"python pipelines/forecast_state.py "
"--disease {disease} "
"--state {state} "
"--n-training-days {n_training_days} "
"--n-warmup {n_warmup} "
"--n-samples {n_samples} "
f"--n-training-days {n_training_days} "
f"--n-warmup {n_warmup} "
f"--n-samples {n_samples} "
"--facility-level-nssp-data-dir nssp-etl/gold "
"--state-level-nssp-data-dir "
"nssp-archival-vintages/gold "
Expand All @@ -171,8 +230,9 @@ def main(
"--priors-path pipelines/priors/prod_priors.py "
"--credentials-path config/creds.toml "
"--report-date {report_date} "
"--exclude-last-n-days {exclude_last_n_days} "
f"--exclude-last-n-days {exclude_last_n_days} "
"--no-score "
f"{hew_flags} "
"--eval-data-path "
"nssp-etl/latest_comprehensive.parquet"
"'"
Expand All @@ -197,10 +257,6 @@ def main(
state=state,
disease=disease,
report_date="latest",
n_warmup=n_warmup,
n_samples=n_samples,
n_training_days=n_training_days,
exclude_last_n_days=exclude_last_n_days,
output_dir=str(Path("output", output_subdir)),
),
container_settings=container_settings,
Expand All @@ -210,98 +266,141 @@ def main(
return None


parser = argparse.ArgumentParser()

parser.add_argument("job_id", type=str, help="Name for the Azure batch job")
parser.add_argument(
"pool_id",
type=str,
help=("Name of the Azure batch pool on which to run the job"),
)
parser.add_argument(
"--diseases",
type=str,
default="COVID-19 Influenza",
help=(
"Name(s) of disease(s) to run as part of the job, "
"as a whitespace-separated string. Supported "
"values are 'COVID-19' and 'Influenza'. "
"Default 'COVID-19 Influenza' (i.e. run for both)."
),
)

parser.add_argument(
"--output-subdir",
type=str,
help=(
"Subdirectory of the output blob storage container "
"in which to save results."
),
default="./",
)

parser.add_argument(
"--container-image-name",
type=str,
help="Name of the container to use for the job.",
default="pyrenew-hew",
)

parser.add_argument(
"--container-image-version",
type=str,
help="Version of the container to use for the job.",
default="latest",
)

parser.add_argument(
"--n-training-days",
type=int,
help=(
"Number of 'training days' of observed data "
"to use for model fitting."
),
default=90,
)

parser.add_argument(
"--exclude-last-n-days",
type=int,
help=(
"Number of days to drop from the end of the timeseries "
"of observed data when constructing the training data."
),
default=1,
)

parser.add_argument(
"--locations-include",
type=str,
help=(
"Two-letter USPS location abbreviations to "
"include in the job, as a whitespace-separated "
"string. If not set, include all ",
"available locations except any explicitly excluded "
"via --locations-exclude.",
),
default=None,
)

parser.add_argument(
"--locations-exclude",
type=str,
help=(
"Two-letter USPS location abbreviations to "
"exclude from the job, as a whitespace-separated "
"string. Defaults to a set of locations for which "
"we typically do not have available NSSP ED visit "
"data: 'AS GU MO MP PR UM VI WY'."
),
default="AS GU MO MP PR UM VI WY",
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument(
"job_id", type=str, help="Name for the Azure batch job"
)
parser.add_argument(
"pool_id",
type=str,
help=("Name of the Azure batch pool on which to run the job"),
)
parser.add_argument(
"--diseases",
type=str,
default="COVID-19 Influenza",
help=(
"Name(s) of disease(s) to run as part of the job, "
"as a whitespace-separated string. Supported "
"values are 'COVID-19' and 'Influenza'. "
"Default 'COVID-19 Influenza' (i.e. run for both)."
),
)

parser.add_argument(
"--output-subdir",
type=str,
help=(
"Subdirectory of the output blob storage container "
"in which to save results."
),
default="./",
)

parser.add_argument(
"--container-image-name",
type=str,
help="Name of the container to use for the job.",
default="pyrenew-hew",
)

parser.add_argument(
"--container-image-version",
type=str,
help="Version of the container to use for the job.",
default="latest",
)

parser.add_argument(
"--fit-ed-visits",
type=bool,
action=argparse.BooleanOptionalAction,
help="If provided, fit to ED visit data.",
)

parser.add_argument(
"--fit-hospital-admissions",
type=bool,
action=argparse.BooleanOptionalAction,
help=("If provided, fit to hospital admissions data."),
)

parser.add_argument(
"--fit-wastewater",
type=bool,
action=argparse.BooleanOptionalAction,
help="If provided, fit to wastewater data.",
)

parser.add_argument(
"--forecast-ed-visits",
type=bool,
action=argparse.BooleanOptionalAction,
help="If provided, forecast ED visits.",
)

parser.add_argument(
"--forecast-hospital-admissions",
type=bool,
action=argparse.BooleanOptionalAction,
help=("If provided, forecast hospital admissions."),
)

parser.add_argument(
"--forecast-wastewater",
type=bool,
action=argparse.BooleanOptionalAction,
help="If provided, forecast wastewater concentrations.",
)

parser.add_argument(
"--n-training-days",
type=int,
help=(
"Number of 'training days' of observed data "
"to use for model fitting."
),
default=90,
)

parser.add_argument(
"--exclude-last-n-days",
type=int,
help=(
"Number of days to drop from the end of the timeseries "
"of observed data when constructing the training data."
),
default=1,
)

parser.add_argument(
"--locations-include",
type=str,
help=(
"Two-letter USPS location abbreviations to "
"include in the job, as a whitespace-separated "
"string. If not set, include all ",
"available locations except any explicitly excluded "
"via --locations-exclude.",
),
default=None,
)

parser.add_argument(
"--locations-exclude",
type=str,
help=(
"Two-letter USPS location abbreviations to "
"exclude from the job, as a whitespace-separated "
"string. Defaults to a set of locations for which "
"we typically do not have available NSSP ED visit "
"data: 'AS GU MO MP PR UM VI WY'."
),
default="AS GU MO MP PR UM VI WY",
)

if __name__ == "__main__":
args = parser.parse_args()
args.diseases = args.diseases.split()
if args.locations_include is not None:
Expand Down
Loading
Loading