Skip to content

Commit

Permalink
Add api key for nhsn to batch pipeline (#348)
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris authored Feb 12, 2025
1 parent b80ff42 commit 1c6ada6
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 3 deletions.
1 change: 1 addition & 0 deletions pipelines/batch/setup_prod_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def main(
"--param-data-dir params "
"--output-dir {output_dir} "
"--priors-path pipelines/priors/prod_priors.py "
"--credentials-path config/creds.toml "
"--report-date {report_date} "
"--exclude-last-n-days {exclude_last_n_days} "
"--no-score "
Expand Down
24 changes: 24 additions & 0 deletions pipelines/forecast_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,26 @@ def main(
exclude_last_n_days: int = 0,
score: bool = False,
eval_data_path: Path = None,
credentials_path: Path = None,
):
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

if credentials_path is not None:
cp = Path(credentials_path)
if not cp.suffix.lower() == ".toml":
raise ValueError(
"Credentials file must have the extension "
"'.toml' (not case-sensitive). Got "
f"{cp.suffix}"
)
logger.info(f"Reading in credentials from {cp}...")
with open(cp, "rb") as fp:
credentials_dict = tomllib.load(fp)
else:
logger.info("No credentials file given. Will proceed without one.")
credentials_dict = None

available_facility_level_reports = get_available_reports(
facility_level_nssp_data_dir
)
Expand Down Expand Up @@ -310,6 +326,7 @@ def main(
param_estimates=param_estimates,
model_run_dir=model_run_dir,
logger=logger,
credentials_dict=credentials_dict,
)
logger.info("Getting eval data...")
if eval_data_path is None:
Expand All @@ -322,6 +339,7 @@ def main(
latest_comprehensive_path=eval_data_path,
output_data_dir=Path(model_run_dir, "data"),
last_eval_date=report_date + timedelta(days=n_forecast_days),
credentials_dict=credentials_dict,
)
logger.info("Done getting eval data.")

Expand Down Expand Up @@ -452,6 +470,12 @@ def main(
required=True,
)

parser.add_argument(
"--credentials-path",
type=Path,
help=("Path to a TOML file containing credentials such as API keys."),
)

parser.add_argument(
"--output-dir",
type=Path,
Expand Down
17 changes: 14 additions & 3 deletions pipelines/prep_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@ def get_nhsn(
end_date: datetime.date,
disease: str,
state_abb: str,
temp_dir=None,
temp_dir: Path = None,
credentials_dict: dict = None,
) -> None:
if temp_dir is None:
temp_dir = tempfile.mkdtemp()
if credentials_dict is None:
credentials_dict = dict()

def py_scalar_to_r_scalar(py_scalar):
if py_scalar is None:
Expand All @@ -43,14 +46,20 @@ def py_scalar_to_r_scalar(py_scalar):
state_abb_for_query = state_abb if state_abb != "US" else "USA"

temp_file = Path(temp_dir, "nhsn_temp.parquet")
api_key_id = credentials_dict.get(
"nhsn_api_key_id", os.getenv("NHSN_API_KEY_ID")
)
api_key_secret = credentials_dict.get(
"nhsn_api_key_secret", os.getenv("NHSN_API_KEY_SECRET")
)

r_command = [
"Rscript",
"-e",
f"""
forecasttools::pull_nhsn(
api_key_id = {py_scalar_to_r_scalar(os.getenv("NHSN_API_KEY_ID"))},
api_key_secret = {py_scalar_to_r_scalar(os.getenv("NHSN_API_KEY_SECRET"))},
api_key_id = {py_scalar_to_r_scalar(api_key_id)},
api_key_secret = {py_scalar_to_r_scalar(api_key_secret)},
start_date = {py_scalar_to_r_scalar(start_date)},
end_date = {py_scalar_to_r_scalar(end_date)},
columns = {py_scalar_to_r_scalar(columns)},
Expand Down Expand Up @@ -345,6 +354,7 @@ def process_and_save_state(
logger: Logger = None,
facility_level_nssp_data: pl.LazyFrame = None,
state_level_nssp_data: pl.LazyFrame = None,
credentials_dict: dict = None,
) -> None:
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -406,6 +416,7 @@ def process_and_save_state(
end_date=last_training_date,
disease=disease,
state_abb=state_abb,
credentials_dict=credentials_dict,
).with_columns(pl.lit("train").alias("data_type"))

nssp_training_dates = (
Expand Down
2 changes: 2 additions & 0 deletions pipelines/prep_eval_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def save_eval_data(
output_data_dir: Path | str,
last_eval_date: datetime.date = None,
output_file_name: str = "eval_data.tsv",
credentials_dict: dict = None,
):
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -49,6 +50,7 @@ def save_eval_data(
end_date=None,
disease=disease,
state_abb=state,
credentials_dict=credentials_dict,
).with_columns(data_type=pl.lit("eval"))

combined_eval_dat = combine_nssp_and_nhsn(
Expand Down

0 comments on commit 1c6ada6

Please sign in to comment.