-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge in code to run and score retrospective evaluation (#102)
- Loading branch information
1 parent
d7e7e16
commit b905057
Showing
26 changed files
with
2,432 additions
and
252 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
*.bin | ||
*.xls | ||
*.xlsx | ||
*.rds | ||
|
||
# Documents | ||
*.doc | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Lookup table from the lower-case disease token that appears in a
# model batch directory name to the disease's display name.
disease_map_lower <- list(
  `covid-19` = "COVID-19",
  influenza = "Influenza"
)
|
||
#' Parse the name of a model batch directory
#'
#' A model batch directory represents a single report date
#' and disease pair, but potentially holds fits for multiple
#' locations. Its name is expected to have the form
#' `<disease>_r_<report_date>_f_<first_training_date>_t_<last_training_date>`.
#' The disease component is matched case-insensitively against
#' the names of `disease_map_lower`; the three date components
#' are parsed with [lubridate::ymd()].
#'
#' @param model_batch_dir_name Name of the model batch
#' directory (not the full path to it, just the directory
#' base name) to parse.
#' @return A named list of quantities: `disease`, `report_date`,
#' `first_training_date`, and `last_training_date`.
#' @export
parse_model_batch_dir <- function(model_batch_dir_name) {
  pattern <- "(.+)_r_(.+)_f_(.+)_t_(.+)"

  matches <- stringr::str_match(
    model_batch_dir_name,
    pattern
  )

  # str_match() fills the row with NA when the pattern does not match.
  if (is.na(matches[1])) {
    stop(
      "Invalid format for model batch directory name; ",
      "could not parse. Expected ",
      "'<disease>_r_<report_date>_f_<first_training_date>_t_",
      "<last_training_date>'."
    )
  }

  # The keys of `disease_map_lower` are lower-case, so normalise the
  # parsed token before looking it up. `[[` on a list returns NULL for
  # a missing name; fail loudly instead of returning `disease = NULL`.
  disease <- disease_map_lower[[tolower(matches[2])]]
  if (is.null(disease)) {
    stop(
      "Unknown disease '", matches[2], "' in model batch ",
      "directory name. Expected one of: ",
      paste(names(disease_map_lower), collapse = ", "),
      "."
    )
  }

  return(list(
    disease = disease,
    report_date = lubridate::ymd(matches[3]),
    first_training_date = lubridate::ymd(matches[4]),
    last_training_date = lubridate::ymd(matches[5])
  ))
}
|
||
#' Parse the path to a model run directory
#'
#' The final path component is taken as the location, and the
#' name of its parent directory is parsed as a model batch
#' directory name via [parse_model_batch_dir()].
#'
#' @param model_run_dir_path Path to parse.
#' @return A list of parsed attributes:
#' `location`, `disease`, `report_date`,
#' `first_training_date`, and `last_training_date`.
#'
#' @export
parse_model_run_dir <- function(model_run_dir_path) {
  run_dir_name <- fs::path_file(model_run_dir_path)
  batch_dir_name <- fs::path_file(fs::path_dir(model_run_dir_path))

  batch_info <- parse_model_batch_dir(batch_dir_name)

  # Location first, followed by the quantities parsed from the
  # batch directory name.
  c(list(location = run_dir_name), batch_info)
}
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,219 @@ | ||
""" | ||
Set up a multi-location, multi-date,
potentially multi-disease end-to-end
retrospective evaluation run for pyrenew-hew
on Azure Batch.
""" | ||
|
||
import argparse | ||
import datetime | ||
import itertools | ||
|
||
import polars as pl | ||
from azure.batch import models | ||
from azuretools.auth import EnvCredentialHandler | ||
from azuretools.client import get_batch_service_client | ||
from azuretools.job import create_job_if_not_exists | ||
from azuretools.task import get_container_settings, get_task_config | ||
|
||
|
||
def main(
    job_id: str,
    pool_id: str,
    diseases: str,
    container_image_name: str = "pyrenew-hew",
    container_image_version: str = "latest",
    excluded_locations: "list[str] | None" = None,
) -> None:
    """
    Create an Azure Batch job (if it does not already exist) and
    add one forecasting task to it for every combination of
    requested disease, report date, and included location.

    Report dates are the 30 weekly dates starting 2023-10-11.

    Parameters
    ----------
    job_id
        Name for the Batch job.
    pool_id
        Azure Batch pool on which to run the job.
    diseases
        Name(s) of disease(s) to run as part of the job,
        as a whitespace-separated string. Supported
        values are 'COVID-19' and 'Influenza'. Repeated
        names are ignored.
    container_image_name
        Name of the container to use for the job.
        This container should exist within the Azure
        Container Registry account associated to
        the job. Default 'pyrenew-hew'.
        The container registry account name and endpoint
        will be obtained from local environment variables
        via a :class:`azuretools.auth.EnvCredentialHandler`.
    container_image_version
        Version of the container to use. Default 'latest'.
    excluded_locations
        List of two letter USPS location abbreviations to
        exclude from the job. If ``None`` (the default),
        excludes locations for which we typically do not
        have available NSSP ED visit data:
        ``["AS", "GU", "MO", "MP", "PR", "UM", "VI", "WY"]``.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``diseases`` names any unsupported disease.
    """
    supported_diseases = ["COVID-19", "Influenza"]

    # Build the default inside the call rather than in the signature,
    # so no list object is shared between invocations.
    if excluded_locations is None:
        excluded_locations = [
            "AS",
            "GU",
            "MO",
            "MP",
            "PR",
            "UM",
            "VI",
            "WY",
        ]

    # Drop repeated names while keeping first-seen order: a repeated
    # disease would otherwise generate two tasks with the same id.
    disease_list = list(dict.fromkeys(diseases.split()))

    # Sorted so the error message is deterministic.
    invalid_diseases = sorted(set(disease_list) - set(supported_diseases))
    if invalid_diseases:
        raise ValueError(
            f"Unsupported diseases: {', '.join(invalid_diseases)}; "
            f"supported diseases are: {', '.join(supported_diseases)}"
        )

    creds = EnvCredentialHandler()
    client = get_batch_service_client(creds)
    job = models.JobAddParameter(
        id=job_id,
        pool_info=models.PoolInformation(pool_id=pool_id),
    )
    create_job_if_not_exists(client, job, verbose=True)

    container_image = (
        f"{creds.azure_container_registry_account}."
        f"{creds.azure_container_registry_domain}/"
        f"{container_image_name}:{container_image_version}"
    )
    container_settings = get_container_settings(
        container_image,
        working_directory="containerImageDefault",
        mount_pairs=[
            {
                "source": "nssp-etl",
                "target": "/pyrenew-hew/nssp_demo/nssp-etl",
            },
            {
                "source": "nssp-archival-vintages",
                "target": "/pyrenew-hew/nssp_demo/nssp-archival-vintages",
            },
            {
                "source": "prod-param-estimates",
                "target": "/pyrenew-hew/nssp_demo/params",
            },
            {
                "source": "pyrenew-test-output",
                "target": "/pyrenew-hew/nssp_demo/private_data",
            },
        ],
    )

    # Command template; {disease}, {state} and {report_date} are
    # filled in per task below.
    base_call = (
        "/bin/bash -c '"
        "python nssp_demo/forecast_state.py "
        "--disease {disease} "
        "--state {state} "
        "--n-training-days 365 "
        "--n-warmup 1000 "
        "--n-samples 500 "
        "--facility-level-nssp-data-dir nssp_demo/nssp-etl/gold "
        "--state-level-nssp-data-dir "
        "nssp_demo/nssp-archival-vintages/gold "
        "--param-data-dir nssp_demo/params "
        "--output-data-dir nssp_demo/private_data "
        "--report-date {report_date:%Y-%m-%d} "
        "--exclude-last-n-days 2 "
        "--score "
        "--eval-data-path "
        "nssp_demo/nssp-archival-vintages/latest_comprehensive.parquet"
        "'"
    )

    # Pipe-separated reference table; the STUSAB column is used as
    # the location abbreviation.
    locations = pl.read_csv(
        "https://www2.census.gov/geo/docs/reference/state.txt", separator="|"
    )

    all_locations = (
        locations.filter(~pl.col("STUSAB").is_in(excluded_locations))
        .get_column("STUSAB")
        .to_list()
    )

    report_dates = [
        datetime.date(2023, 10, 11) + datetime.timedelta(weeks=x)
        for x in range(30)
    ]

    for disease, report_date, loc in itertools.product(
        disease_list, report_dates, all_locations
    ):
        task = get_task_config(
            f"{job_id}-{loc}-{disease}-{report_date}",
            base_call=base_call.format(
                state=loc,
                disease=disease,
                report_date=report_date,
            ),
            container_settings=container_settings,
        )
        client.task.add(job_id, task)
|
||
|
||
# Command-line interface. The three positional arguments and the
# optional flags share their names with the parameters of ``main``,
# so the parsed namespace can be unpacked straight into it.
parser = argparse.ArgumentParser()

# Positional arguments.
parser.add_argument("job_id", help="Name for the Azure batch job", type=str)
parser.add_argument(
    "pool_id",
    help="Name of the Azure batch pool on which to run the job",
    type=str,
)
parser.add_argument(
    "diseases",
    help=(
        "Name(s) of disease(s) to run as part of the job, as a "
        "whitespace-separated string. Supported values are "
        "'COVID-19' and 'Influenza'."
    ),
    type=str,
)

# Optional arguments.
parser.add_argument(
    "--container-image-name",
    default="pyrenew-hew",
    help="Name of the container to use for the job.",
    type=str,
)
parser.add_argument(
    "--container-image-version",
    default="latest",
    help="Version of the container to use for the job.",
    type=str,
)
parser.add_argument(
    "--excluded-locations",
    default="AS GU MO MP PR UM VI WY",
    help=(
        "Two-letter USPS location abbreviations to exclude from "
        "the job, as a whitespace-separated string. Defaults to "
        "a set of locations for which we typically do not have "
        "available NSSP ED visit data: 'AS GU MO MP PR UM VI WY'."
    ),
    type=str,
)


if __name__ == "__main__":
    cli_kwargs = vars(parser.parse_args())
    # The flag arrives as one whitespace-separated string; main()
    # takes a list of abbreviations.
    cli_kwargs["excluded_locations"] = cli_kwargs["excluded_locations"].split()
    main(**cli_kwargs)
Oops, something went wrong.