diff --git a/pipelines/batch/setup_prod_job.py b/pipelines/batch/setup_prod_job.py index 84d4cb8b..ddbd4bb1 100644 --- a/pipelines/batch/setup_prod_job.py +++ b/pipelines/batch/setup_prod_job.py @@ -5,9 +5,9 @@ import argparse import itertools +import re from pathlib import Path -import polars as pl from azure.batch import models from azuretools.auth import EnvCredentialHandler from azuretools.client import get_batch_service_client @@ -21,6 +21,12 @@ def main( pool_id: str, diseases: str | list[str], output_subdir: str | Path = "./", + fit_ed_visits: bool = False, + fit_hospital_admissions: bool = False, + fit_wastewater: bool = False, + forecast_ed_visits: bool = False, + forecast_hospital_admissions: bool = False, + forecast_wastewater: bool = False, container_image_name: str = "pyrenew-hew", container_image_version: str = "latest", n_training_days: int = 90, @@ -54,7 +60,25 @@ def main( Subdirectory of the output blob storage container in which to save results. - container_image_name: + fit_ed_visits + Fit to ED visits data? Default ``False``. + + fit_hospital_admissions + Fit to hospital admissions data? Default ``False``. + + fit_wastewater + Fit to wastewater data? Default ``False``. + + forecast_ed_visits + Forecast ED visits? Default ``False``. + + forecast_hospital_admissions + Forecast hospital admissions? Default ``False``. + + forecast_wastewater + Forecast wastewater concentrations? Default ``False``. + + container_image_name Name of the container to use for the job. This container should exist within the Azure Container Registry account associated to @@ -109,6 +133,24 @@ def main( f"supported diseases are: {', '.join(supported_diseases)}" ) + signals = ["ed_visits", "hospital_admissions", "wastewater"] + + for signal in signals: + fit = locals().get(f"fit_{signal}", False) + forecast = locals().get(f"forecast_{signal}", False) + if fit and not forecast: + raise ValueError( + "This pipeline does not currently support " + "fitting to but not forecasting a signal. " + f"Asked to fit but not forecast {signal}." + ) + any_fit = any([locals().get(f"fit_{signal}", False) for signal in signals]) + if not any_fit: + raise ValueError( + "pyrenew_null (fitting to no signals) " + "is not supported by this pipeline" + ) + pyrenew_hew_output_container = ( "pyrenew-test-output" if test else "pyrenew-hew-prod-output" ) @@ -155,14 +197,31 @@ def main( ], ) + needed_hew_flags = [ + "fit_ed_visits", + "fit_hospital_admissions", + "fit_wastewater", + "forecast_ed_visits", + "forecast_hospital_admissions", + "forecast_wastewater", + ] + + def as_flag(flag_name, bool_val): + prefix = "" if bool_val else "no-" + return f"--{prefix}{re.sub("_", "-", flag_name)}" + + hew_flags = " ".join( + [as_flag(k, v) for k, v in locals().items() if k in needed_hew_flags] + ) + base_call = ( "/bin/bash -c '" "python pipelines/forecast_state.py " "--disease {disease} " "--state {state} " - "--n-training-days {n_training_days} " - "--n-warmup {n_warmup} " - "--n-samples {n_samples} " + f"--n-training-days {n_training_days} " + f"--n-warmup {n_warmup} " + f"--n-samples {n_samples} " "--facility-level-nssp-data-dir nssp-etl/gold " "--state-level-nssp-data-dir " "nssp-archival-vintages/gold " @@ -171,8 +230,9 @@ def main( "--priors-path pipelines/priors/prod_priors.py " "--credentials-path config/creds.toml " "--report-date {report_date} " - "--exclude-last-n-days {exclude_last_n_days} " + f"--exclude-last-n-days {exclude_last_n_days} " "--no-score " + f"{hew_flags} " "--eval-data-path " "nssp-etl/latest_comprehensive.parquet" "'" @@ -197,10 +257,6 @@ def main( state=state, disease=disease, report_date="latest", - n_warmup=n_warmup, - n_samples=n_samples, - n_training_days=n_training_days, - exclude_last_n_days=exclude_last_n_days, output_dir=str(Path("output", output_subdir)), ), container_settings=container_settings, @@ -210,98 +266,141 @@ def main( return None -parser = argparse.ArgumentParser() - -parser.add_argument("job_id", type=str, help="Name for the Azure batch job") -parser.add_argument( - "pool_id", - type=str, - help=("Name of the Azure batch pool on which to run the job"), -) -parser.add_argument( - "--diseases", - type=str, - default="COVID-19 Influenza", - help=( - "Name(s) of disease(s) to run as part of the job, " - "as a whitespace-separated string. Supported " - "values are 'COVID-19' and 'Influenza'. " - "Default 'COVID-19 Influenza' (i.e. run for both)." - ), -) - -parser.add_argument( - "--output-subdir", - type=str, - help=( - "Subdirectory of the output blob storage container " - "in which to save results." - ), - default="./", -) - -parser.add_argument( - "--container-image-name", - type=str, - help="Name of the container to use for the job.", - default="pyrenew-hew", -) - -parser.add_argument( - "--container-image-version", - type=str, - help="Version of the container to use for the job.", - default="latest", -) - -parser.add_argument( - "--n-training-days", - type=int, - help=( - "Number of 'training days' of observed data " - "to use for model fitting." - ), - default=90, -) - -parser.add_argument( - "--exclude-last-n-days", - type=int, - help=( - "Number of days to drop from the end of the timeseries " - "of observed data when constructing the training data." - ), - default=1, -) - -parser.add_argument( - "--locations-include", - type=str, - help=( - "Two-letter USPS location abbreviations to " - "include in the job, as a whitespace-separated " - "string. If not set, include all ", - "available locations except any explicitly excluded " - "via --locations-exclude.", - ), - default=None, -) - -parser.add_argument( - "--locations-exclude", - type=str, - help=( - "Two-letter USPS location abbreviations to " - "exclude from the job, as a whitespace-separated " - "string. Defaults to a set of locations for which " - "we typically do not have available NSSP ED visit " - "data: 'AS GU MO MP PR UM VI WY'." - ), - default="AS GU MO MP PR UM VI WY", -) +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "job_id", type=str, help="Name for the Azure batch job" + ) + parser.add_argument( + "pool_id", + type=str, + help=("Name of the Azure batch pool on which to run the job"), + ) + parser.add_argument( + "--diseases", + type=str, + default="COVID-19 Influenza", + help=( + "Name(s) of disease(s) to run as part of the job, " + "as a whitespace-separated string. Supported " + "values are 'COVID-19' and 'Influenza'. " + "Default 'COVID-19 Influenza' (i.e. run for both)." + ), + ) + + parser.add_argument( + "--output-subdir", + type=str, + help=( + "Subdirectory of the output blob storage container " + "in which to save results." + ), + default="./", + ) + + parser.add_argument( + "--container-image-name", + type=str, + help="Name of the container to use for the job.", + default="pyrenew-hew", + ) + + parser.add_argument( + "--container-image-version", + type=str, + help="Version of the container to use for the job.", + default="latest", + ) + + parser.add_argument( + "--fit-ed-visits", + type=bool, + action=argparse.BooleanOptionalAction, + help="If provided, fit to ED visit data.", + ) + + parser.add_argument( + "--fit-hospital-admissions", + type=bool, + action=argparse.BooleanOptionalAction, + help=("If provided, fit to hospital admissions data."), + ) + + parser.add_argument( + "--fit-wastewater", + type=bool, + action=argparse.BooleanOptionalAction, + help="If provided, fit to wastewater data.", + ) + + parser.add_argument( + "--forecast-ed-visits", + type=bool, + action=argparse.BooleanOptionalAction, + help="If provided, forecast ED visits.", + ) + + parser.add_argument( + "--forecast-hospital-admissions", + type=bool, + action=argparse.BooleanOptionalAction, + help=("If provided, forecast hospital admissions."), + ) + + parser.add_argument( + "--forecast-wastewater", + type=bool, + action=argparse.BooleanOptionalAction, + help="If provided, forecast wastewater concentrations.", + ) + + parser.add_argument( + "--n-training-days", + type=int, + help=( + "Number of 'training days' of observed data " + "to use for model fitting." + ), + default=90, + ) + + parser.add_argument( + "--exclude-last-n-days", + type=int, + help=( + "Number of days to drop from the end of the timeseries " + "of observed data when constructing the training data." + ), + default=1, + ) + + parser.add_argument( + "--locations-include", + type=str, + help=( + "Two-letter USPS location abbreviations to " + "include in the job, as a whitespace-separated " + "string. If not set, include all ", + "available locations except any explicitly excluded " + "via --locations-exclude.", + ), + default=None, + ) + + parser.add_argument( + "--locations-exclude", + type=str, + help=( + "Two-letter USPS location abbreviations to " + "exclude from the job, as a whitespace-separated " + "string. Defaults to a set of locations for which " + "we typically do not have available NSSP ED visit " + "data: 'AS GU MO MP PR UM VI WY'." + ), + default="AS GU MO MP PR UM VI WY", + ) -if __name__ == "__main__": args = parser.parse_args() args.diseases = args.diseases.split() if args.locations_include is not None: diff --git a/pipelines/build_pyrenew_model.py b/pipelines/build_pyrenew_model.py index 0019e29f..951761b2 100644 --- a/pipelines/build_pyrenew_model.py +++ b/pipelines/build_pyrenew_model.py @@ -18,6 +18,9 @@ def build_model_from_dir( model_dir: Path, + sample_ed_visits: bool = False, + sample_hospital_admissions: bool = False, + sample_wastewater: bool = False, ) -> tuple[PyrenewHEWModel, PyrenewHEWData]: data_path = Path(model_dir) / "data" / "data_for_model_fit.json" prior_path = Path(model_dir) / "priors.py" @@ -47,12 +50,20 @@ def build_model_from_dir( jnp.array(model_data["generation_interval_pmf"]), ) # check if off by 1 or reversed - data_observed_disease_ed_visits = jnp.array( - model_data["data_observed_disease_ed_visits"] + data_observed_disease_ed_visits = ( + jnp.array(model_data["data_observed_disease_ed_visits"]) + if sample_ed_visits + else None ) - data_observed_disease_hospital_admissions = jnp.array( - model_data["data_observed_disease_hospital_admissions"] + data_observed_disease_hospital_admissions = ( + jnp.array(model_data["data_observed_disease_hospital_admissions"]) + if sample_hospital_admissions + else None ) + + # placeholder + data_observed_disease_wastewater = None if sample_wastewater else None + population_size = jnp.array(model_data["state_pop"]) ed_right_truncation_pmf_rv = DeterministicVariable( @@ -133,7 +144,7 @@ def build_model_from_dir( data_observed_disease_hospital_admissions=( data_observed_disease_hospital_admissions ), - data_observed_disease_wastewater=None, # placeholder + data_observed_disease_wastewater=data_observed_disease_wastewater, right_truncation_offset=right_truncation_offset, first_ed_visits_date=first_ed_visits_date, first_hospital_admissions_date=first_hospital_admissions_date, diff --git a/pipelines/fit_pyrenew_model.py b/pipelines/fit_pyrenew_model.py index 33fb86b9..9fa952ab 100644 --- a/pipelines/fit_pyrenew_model.py +++ b/pipelines/fit_pyrenew_model.py @@ -12,6 +12,9 @@ def fit_and_save_model( model_run_dir: str, model_name: str, + fit_ed_visits: bool = False, + fit_hospital_admissions: bool = False, + fit_wastewater: bool = False, n_warmup: int = 1000, n_samples: int = 1000, n_chains: int = 4, @@ -26,12 +29,17 @@ def fit_and_save_model( "rng_key must be an integer with which " "to seed :func:`jax.random.key`" ) - (my_model, my_data) = build_model_from_dir(model_run_dir) + (my_model, my_data) = build_model_from_dir( + model_run_dir, + sample_ed_visits=fit_ed_visits, + sample_hospital_admissions=fit_hospital_admissions, + sample_wastewater=fit_wastewater, + ) my_model.run( data=my_data, - sample_ed_visits=True, - sample_hospital_admissions=True, - sample_wastewater=False, + sample_ed_visits=fit_ed_visits, + sample_hospital_admissions=fit_hospital_admissions, + sample_wastewater=fit_wastewater, num_warmup=n_warmup, num_samples=n_samples, rng_key=rng_key, @@ -67,6 +75,26 @@ def fit_and_save_model( required=True, help="Name of the model to use for generating predictions.", ) + + parser.add_argument( + "--fit-ed-visits", + type=bool, + action=argparse.BooleanOptionalAction, + help="If provided, fit to ED visit data.", + ) + parser.add_argument( + "--fit-hospital-admissions", + type=bool, + action=argparse.BooleanOptionalAction, + help=("If provided, fit to hospital admissions data."), + ) + parser.add_argument( + "--fit-wastewater", + type=bool, + action=argparse.BooleanOptionalAction, + help="If provided, fit to wastewater data.", + ) + parser.add_argument( "--n-warmup", type=int, diff --git a/pipelines/forecast_state.py b/pipelines/forecast_state.py index a3bd4f48..da873c03 100644 --- a/pipelines/forecast_state.py +++ b/pipelines/forecast_state.py @@ -14,6 +14,8 @@ from prep_eval_data import save_eval_data from pygit2 import Repository +from pyrenew_hew.util import pyrenew_model_name_from_flags + numpyro.set_host_device_count(4) from fit_pyrenew_model import fit_and_save_model # noqa @@ -202,10 +204,45 @@ def main( score: bool = False, eval_data_path: Path = None, credentials_path: Path = None, -): + fit_ed_visits: bool = False, + fit_hospital_admissions: bool = False, + fit_wastewater: bool = False, + forecast_ed_visits: bool = False, + forecast_hospital_admissions: bool = False, + forecast_wastewater: bool = False, +) -> None: logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + pyrenew_model_name = pyrenew_model_name_from_flags( + fit_ed_visits=fit_ed_visits, + fit_hospital_admissions=fit_hospital_admissions, + fit_wastewater=fit_wastewater, + ) + + logger.info( + "Starting single-location forecasting pipeline for " + f"model {pyrenew_model_name}, location {state}, " + f"and report date {report_date}" + ) + signals = ["ed_visits", "hospital_admissions", "wastewater"] + + for signal in signals: + fit = locals().get(f"fit_{signal}", False) + forecast = locals().get(f"forecast_{signal}", False) + if fit and not forecast: + raise ValueError( + "This pipeline does not currently support " + "fitting to but not forecasting a signal. " + f"Asked to fit but not forecast {signal}." + ) + any_fit = any([locals().get(f"fit_{signal}", False) for signal in signals]) + if not any_fit: + raise ValueError( + "pyrenew_null (fitting to no signals) " + "is not supported by this pipeline" + ) + if credentials_path is not None: cp = Path(credentials_path) if not cp.suffix.lower() == ".toml": @@ -351,10 +388,13 @@ def main( logger.info("Fitting model") fit_and_save_model( model_run_dir, - "pyrenew_e", + pyrenew_model_name, n_warmup=n_warmup, n_samples=n_samples, n_chains=n_chains, + fit_ed_visits=fit_ed_visits, + fit_hospital_admissions=fit_hospital_admissions, + fit_wastewater=fit_wastewater, ) logger.info("Model fitting complete") @@ -362,7 +402,12 @@ def main( n_days_past_last_training = n_forecast_days + exclude_last_n_days generate_and_save_predictions( - model_run_dir, "pyrenew_e", n_days_past_last_training + model_run_dir, + pyrenew_model_name, + n_days_past_last_training, + predict_ed_visits=forecast_ed_visits, + predict_hospital_admissions=forecast_hospital_admissions, + predict_wastewater=forecast_wastewater, ) logger.info( @@ -379,11 +424,13 @@ def main( logger.info("All forecasting complete.") logger.info("Converting inferencedata to parquet...") - convert_inferencedata_to_parquet(model_run_dir, "pyrenew_e") + convert_inferencedata_to_parquet(model_run_dir, pyrenew_model_name) logger.info("Conversion complete.") logger.info("Postprocessing forecast...") - plot_and_save_state_forecast(model_run_dir, "pyrenew_e", "timeseries_e") + plot_and_save_state_forecast( + model_run_dir, pyrenew_model_name, "timeseries_e" + ) logger.info("Postprocessing complete.") logger.info("Rendering webpage...") @@ -395,8 +442,9 @@ def main( score_forecast(model_run_dir) logger.info( - "Single state pipeline complete " - f"for state {state} with " + "Single-location pipeline complete " + f"for model {pyrenew_model_name}, " + f"location {state}, and " f"report date {report_date}." ) return None @@ -548,6 +596,45 @@ def main( type=Path, help=("Path to a parquet file containing compehensive truth data."), ) + + parser.add_argument( + "--fit-ed-visits", + type=bool, + action=argparse.BooleanOptionalAction, + help="If provided, fit to ED visit data.", + ) + parser.add_argument( + "--fit-hospital-admissions", + type=bool, + action=argparse.BooleanOptionalAction, + help=("If provided, fit to hospital admissions data."), + ) + parser.add_argument( + "--fit-wastewater", + type=bool, + action=argparse.BooleanOptionalAction, + help="If provided, fit to wastewater data.", + ) + + parser.add_argument( + "--forecast-ed-visits", + type=bool, + action=argparse.BooleanOptionalAction, + help="If provided, forecast ED visits.", + ) + parser.add_argument( + "--forecast-hospital-admissions", + type=bool, + action=argparse.BooleanOptionalAction, + help=("If provided, forecast hospital admissions."), + ) + parser.add_argument( + "--forecast-wastewater", + type=bool, + action=argparse.BooleanOptionalAction, + help="If provided, forecast wastewater concentrations.", + ) + args = parser.parse_args() numpyro.set_host_device_count(args.n_chains) main(**vars(args)) diff --git a/pipelines/generate_predictive.py b/pipelines/generate_predictive.py index d9fea1e3..f1a319a0 100644 --- a/pipelines/generate_predictive.py +++ b/pipelines/generate_predictive.py @@ -9,13 +9,23 @@ def generate_and_save_predictions( - model_run_dir: str | Path, model_name: str, n_forecast_points: int + model_run_dir: str | Path, + model_name: str, + n_forecast_points: int, + predict_ed_visits: bool = False, + predict_hospital_admissions: bool = False, + predict_wastewater: bool = False, ) -> None: model_run_dir = Path(model_run_dir) model_dir = Path(model_run_dir, model_name) if not model_dir.exists(): raise FileNotFoundError(f"The directory {model_dir} does not exist.") - (my_model, my_data) = build_model_from_dir(model_run_dir) + (my_model, my_data) = build_model_from_dir( + model_run_dir, + sample_ed_visits=predict_ed_visits, + sample_hospital_admissions=predict_hospital_admissions, + sample_wastewater=predict_wastewater, + ) my_model._init_model(1, 1) fresh_sampler = my_model.mcmc.sampler @@ -31,9 +41,9 @@ def generate_and_save_predictions( posterior_predictive = my_model.posterior_predictive( data=forecast_data, - sample_ed_visits=True, - sample_hospital_admissions=True, - sample_wastewater=False, + sample_ed_visits=predict_ed_visits, + sample_hospital_admissions=predict_hospital_admissions, + sample_wastewater=predict_wastewater, ) idata = az.from_numpyro( @@ -73,6 +83,28 @@ def generate_and_save_predictions( default=0, help="Number of time points to forecast (Default: 0).", ) + parser.add_argument( + "--predict-ed-visits", + type=bool, + action=argparse.BooleanOptionalAction, + help="If provided, generate posterior predictions for ED visits.", + ) + parser.add_argument( + "--predict-hospital-admissions", + type=bool, + action=argparse.BooleanOptionalAction, + help=( + "If provided, generate posterior predictions " + "for hospital admissions." + ), + ) + parser.add_argument( + "--predict-wastewater", + type=bool, + action=argparse.BooleanOptionalAction, + help="If provided, generate posterior predictions for wastewater.", + ) + args = parser.parse_args() generate_and_save_predictions(**vars(args)) diff --git a/pipelines/tests/test_build_pyrenew_model.py b/pipelines/tests/test_build_pyrenew_model.py new file mode 100644 index 00000000..e8b5cdf1 --- /dev/null +++ b/pipelines/tests/test_build_pyrenew_model.py @@ -0,0 +1,90 @@ +import json +from pathlib import Path + +import jax.numpy as jnp +import pytest + +from pipelines.build_pyrenew_model import build_model_from_dir + + +@pytest.fixture +def mock_data(): + return json.dumps( + { + "data_observed_disease_ed_visits": [1, 2, 3], + "data_observed_disease_hospital_admissions": [4, 5, 6], + "state_pop": [7, 8, 9], + "generation_interval_pmf": [0.1, 0.2, 0.7], + "inf_to_ed_pmf": [0.4, 0.5, 0.1], + "right_truncation_pmf": [0.7, 0.1, 0.2], + "nssp_training_dates": ["2025-01-01"], + "nhsn_training_dates": ["2025-01-02"], + "right_truncation_offset": 10, + } + ) + + +@pytest.fixture +def mock_priors(): + return """ +from pyrenew.deterministic import NullVariable + +i0_first_obs_n_rv = None +initialization_rate_rv = None +log_r_mu_intercept_rv = None +autoreg_rt_rv = None +eta_sd_rv = None +inf_feedback_strength_rv = NullVariable() +p_ed_visit_mean_rv = None +p_ed_visit_w_sd_rv = None +autoreg_p_ed_visit_rv = None +ed_visit_wday_effect_rv = None +ed_neg_bin_concentration_rv = None +hosp_admit_neg_bin_concentration_rv = None +ihr_rv = None +t_peak_rv = None +duration_shed_after_peak_rv = None +log10_genome_per_inf_ind_rv = None +mode_sigma_ww_site_rv = None +sd_log_sigma_ww_site_rv = None +mode_sd_ww_site_rv = None +max_shed_interval = None +ww_ml_produced_per_day = None +""" + + +def test_build_model_from_dir(tmp_path, mock_data, mock_priors): + model_dir = tmp_path / "model_dir" + data_dir = model_dir / "data" + data_dir.mkdir(parents=True) + + (model_dir / "priors.py").write_text(mock_priors) + + data_path = data_dir / "data_for_model_fit.json" + data_path.write_text(mock_data) + + model_data = json.loads(mock_data) + + # Test when all sample arguments are False + _, data = build_model_from_dir(model_dir) + assert data.data_observed_disease_ed_visits is None + assert data.data_observed_disease_hospital_admissions is None + assert data.data_observed_disease_wastewater is None + + # Test when all sample arguments are True + _, data = build_model_from_dir( + model_dir, + sample_ed_visits=True, + sample_hospital_admissions=True, + sample_wastewater=True, + ) + assert jnp.array_equal( + data.data_observed_disease_ed_visits, + jnp.array(model_data["data_observed_disease_ed_visits"]), + ) + assert jnp.array_equal( + data.data_observed_disease_hospital_admissions, + jnp.array(model_data["data_observed_disease_hospital_admissions"]), + ) + assert data.data_observed_disease_wastewater is None + ## Update this if wastewater data is added later diff --git a/pipelines/tests/test_end_to_end.sh b/pipelines/tests/test_end_to_end.sh index 62e3f583..309bb8b4 100755 --- a/pipelines/tests/test_end_to_end.sh +++ b/pipelines/tests/test_end_to_end.sh @@ -33,6 +33,12 @@ do --n-chains 2 \ --n-samples 250 \ --n-warmup 250 \ + --fit-ed-visits \ + --no-fit-hospital-admissions \ + --no-fit-wastewater \ + --forecast-ed-visits \ + --forecast-hospital-admissions \ + --no-forecast-wastewater \ --score \ --eval-data-path "$BASE_DIR/private_data/nssp-etl" if [ $? -ne 0 ]; then diff --git a/pyrenew_hew/util.py b/pyrenew_hew/util.py new file mode 100644 index 00000000..a4bff42b --- /dev/null +++ b/pyrenew_hew/util.py @@ -0,0 +1,111 @@ +""" +Pyrenew-HEW utilities +""" + +from itertools import chain, combinations +from typing import Iterable + + +def powerset(iterable: Iterable) -> Iterable: + """ + Subsequences of the iterable from shortest to longest, + considering only unique elements. + + Adapted from https://docs.python.org/3/library/itertools.html#itertools-recipes + """ + s = set(iterable) + return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) + + +def hew_models(with_null: bool = True) -> Iterable: + """ + Return an iterable of the Pyrenew-HEW models + as their lowercase letters. + + Parameters + ---------- + with_null + Include the null model ("pyrenew_null"), represented as + the empty tuple `()`? Default ``True``. + + Returns + ------- + Iterable + An iterable yielding tuples of model letters. + """ + result = powerset(("h", "e", "w")) + if not with_null: + result = filter(None, result) + return result + + +def hew_letters_from_flags( + fit_ed_visits: bool = False, + fit_hospital_admissions: bool = False, + fit_wastewater: bool = False, +) -> str: + """ + Get the {h, e, w} letters defining + a model from a set of flags indicating which + of the datastreams, if any, were used in fitting. + If none of them were, return the string "null" + + Parameters + ---------- + fit_ed_visits + ED visit data used in fitting? + + fit_hospital_admissions + Hospital admissions data used in fitting? + + fit_wastewater + Wastewater data used in fitting? + + Returns + ------- + str + The relevant HEW letters, or 'null', + a""" + result = ( + f"{'h' if fit_hospital_admissions else ''}" + f"{'e' if fit_ed_visits else ''}" + f"{'w' if fit_wastewater else ''}" + ) + if not result: + result = "null" + return result + + +def pyrenew_model_name_from_flags( + fit_ed_visits: bool = False, + fit_hospital_admissions: bool = False, + fit_wastewater: bool = False, +) -> str: + """ + Get a "pyrenew_{h,e,w}" model name + string from a set of flags indicating which + of the datastreams, if any, were used in fitting. + If none of them were, call the model "pyrenew_null" + + Parameters + ---------- + fit_ed_visits + ED visit data used in fitting? + + fit_hospital_admissions + Hospital admissions data used in fitting? + + fit_wastewater + Wastewater data used in fitting? + + Returns + ------- + str + The model name. + """ + hew_letters = hew_letters_from_flags( + fit_ed_visits=fit_ed_visits, + fit_hospital_admissions=fit_hospital_admissions, + fit_wastewater=fit_wastewater, + ) + return f"pyrenew_{hew_letters}" diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 00000000..9628edb3 --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,94 @@ +from typing import Iterable + +import pytest + +from pyrenew_hew.util import ( + hew_letters_from_flags, + hew_models, + powerset, + pyrenew_model_name_from_flags, +) + + +@pytest.mark.parametrize( + [ + "fit_ed_visits", + "fit_hospital_admissions", + "fit_wastewater", + "expected_letters", + ], + [ + (False, False, False, "null"), + (False, True, False, "h"), + (True, False, False, "e"), + (False, False, True, "w"), + (True, True, False, "he"), + (False, True, True, "hw"), + (True, False, True, "ew"), + (True, True, True, "hew"), + ], +) +def test_hew_naming_from_flags( + fit_ed_visits, fit_hospital_admissions, fit_wastewater, expected_letters +): + assert ( + hew_letters_from_flags( + fit_ed_visits, fit_hospital_admissions, fit_wastewater + ) + == expected_letters + ) + + assert ( + pyrenew_model_name_from_flags( + fit_ed_visits, fit_hospital_admissions, fit_wastewater + ) + == f"pyrenew_{expected_letters}" + ) + + +@pytest.mark.parametrize( + "test_items", + [ + range(10), + ["a", "b", "c"], + [None, "a", "b"], + [None, None, "a", "b"], + ["a", "b", "a", "a"], + [1, 1, 1.5, 2], + ], +) +def test_powerset(test_items): + pset_iter = powerset(test_items) + pset = set(pset_iter) + assert isinstance(pset_iter, Iterable) + assert set([(item,) for item in test_items]).issubset(pset) + assert len(pset) == 2 ** len(set(test_items)) + assert () in pset + + +def test_hew_model_iterator(): + expected = [ + (), + ("h",), + ("e",), + ("w",), + ( + "e", + "w", + ), + ( + "h", + "e", + ), + ( + "h", + "w", + ), + ("h", "e", "w"), + ] + assert set([tuple(sorted(i)) for i in hew_models()]) == set( + [tuple(sorted(i)) for i in expected] + ) + assert set([tuple(sorted(i)) for i in hew_models(False)]) == set( + [tuple(sorted(i)) for i in filter(None, expected)] + )