Skip to content

Commit

Permalink
Merge in code to run and score retrospective evaluation (#102)
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris authored Nov 8, 2024
1 parent d7e7e16 commit b905057
Show file tree
Hide file tree
Showing 26 changed files with 2,432 additions and 252 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
*.bin
*.xls
*.xlsx
*.rds

# Documents
*.doc
Expand Down
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ repos:
#####
# Basic file cleanliness
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: check-added-large-files
- id: check-yaml
Expand All @@ -13,7 +13,7 @@ repos:
#####
# Python
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.8
rev: v0.7.1
hooks:
# Sort imports
- id: ruff
Expand All @@ -26,13 +26,13 @@ repos:
#####
# R
- repo: https://github.com/lorenzwalthert/precommit
rev: v0.4.3
rev: v0.4.3.9001
hooks:
- id: style-files
- id: lintr
# Secrets
- repo: https://github.com/Yelp/detect-secrets
rev: v1.4.0
rev: v1.5.0
hooks:
- id: detect-secrets
args: ["--baseline", ".secrets.baseline"]
Expand Down
3 changes: 2 additions & 1 deletion Containerfile
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ WORKDIR pyrenew-hew
COPY .ContainerBuildRprofile .Rprofile

RUN Rscript -e "install.packages('pak')"
RUN Rscript -e "pak::pkg_install('cmu-delphi/epiprocess@main')"
RUN Rscript -e "pak::pkg_install('cmu-delphi/epipredict@main')"
RUN Rscript -e "pak::local_install('hewr')"

COPY . .

RUN pip install --root-user-action=ignore .
4 changes: 3 additions & 1 deletion hewr/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ License: `use_mit_license()`, `use_gpl3_license()` or friends to pick a
license
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
Imports:
argparser,
arrow,
cowplot,
dplyr,
fable,
feasts,
forcats,
fs,
ggplot2,
glue,
Expand All @@ -25,6 +26,7 @@ Imports:
purrr,
readr,
scales,
scoringutils (>= 2.0.0),
stringr,
tibble,
tidybayes,
Expand Down
63 changes: 63 additions & 0 deletions hewr/R/parse_path.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Map lowercase disease tokens (as they appear in batch directory names)
# to their canonical display names.
disease_map_lower <- as.list(c(
  "covid-19" = "COVID-19",
  "influenza" = "Influenza"
))

#' Parse the name of a model batch directory
#' (i.e. a directory representing a single
#' report date and disease pair, but potentially
#' with fits for multiple locations), returning
#' a named list of quantities of interest.
#'
#' @param model_batch_dir_name Name of the model batch
#' directory (not the full path to it, just the directory
#' base name) to parse. Expected format:
#' `<disease>_r_<report_date>_f_<first_training_date>_t_<last_training_date>`.
#' @return A list of quantities: `disease`, `report_date`,
#' `first_training_date`, and `last_training_date`.
#' Errors if the name does not match the expected format
#' or if the disease is not a recognized disease.
#' @export
parse_model_batch_dir <- function(model_batch_dir_name) {
  pattern <- "(.+)_r_(.+)_f_(.+)_t_(.+)"

  matches <- stringr::str_match(
    model_batch_dir_name,
    pattern
  )

  # matches[1] is the full match; NA means the pattern did not match at all.
  if (is.na(matches[1])) {
    stop(
      "Invalid format for model batch directory name; ",
      "could not parse. Expected ",
      "'<disease>_r_<report_date>_f_<first_training_date>_t_",
      "<last_training_date>'."
    )
  }

  # Fail loudly on an unrecognized disease rather than silently
  # returning NULL from an unmatched list lookup.
  disease_key <- matches[2]
  if (!(disease_key %in% names(disease_map_lower))) {
    stop(
      "Unknown disease '", disease_key, "' in model batch ",
      "directory name; expected one of: ",
      toString(names(disease_map_lower)), "."
    )
  }

  return(list(
    disease = disease_map_lower[[disease_key]],
    report_date = lubridate::ymd(matches[3]),
    first_training_date = lubridate::ymd(matches[4]),
    last_training_date = lubridate::ymd(matches[5])
  ))
}

#' Parse path to a model run directory
#' (i.e. a directory representing a run for a
#' particular location, disease, and reference
#' date, and extract key quantities of interest.
#'
#' @param model_run_dir_path Path to parse.
#' @return A list of parsed attributes:
#' `location`, `disease`, `report_date`,
#' `first_training_date`, and `last_training_date`.
#'
#' @export
parse_model_run_dir <- function(model_run_dir_path) {
  # The final path component is the location; its parent directory
  # is the model batch directory, whose name encodes the rest.
  location <- fs::path_file(model_run_dir_path)
  batch_dir_name <- fs::path_file(fs::path_dir(model_run_dir_path))

  c(
    list(location = location),
    parse_model_batch_dir(batch_dir_name)
  )
}
28 changes: 0 additions & 28 deletions nssp_demo/all_post_process.sh

This file was deleted.

219 changes: 219 additions & 0 deletions nssp_demo/batch/setup_eval_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
"""
Set up a multi-location, multi-date,
potentially multi-disease end to end
retrospective evaluation run for pyrenew-hew
on Azure Batch.
"""

import argparse
import datetime
import itertools

import polars as pl
from azure.batch import models
from azuretools.auth import EnvCredentialHandler
from azuretools.client import get_batch_service_client
from azuretools.job import create_job_if_not_exists
from azuretools.task import get_container_settings, get_task_config


def main(
    job_id: str,
    pool_id: str,
    diseases: str,
    container_image_name: str = "pyrenew-hew",
    container_image_version: str = "latest",
    excluded_locations: list[str] | None = None,
) -> None:
    """
    Set up a multi-location, multi-date, potentially
    multi-disease retrospective evaluation job on Azure Batch.

    Parameters
    ----------
    job_id
        Name for the Batch job.
    pool_id
        Azure Batch pool on which to run the job.
    diseases
        Name(s) of disease(s) to run as part of the job,
        as a whitespace-separated string. Supported
        values are 'COVID-19' and 'Influenza'.
    container_image_name
        Name of the container to use for the job.
        This container should exist within the Azure
        Container Registry account associated to
        the job. Default 'pyrenew-hew'.
        The container registry account name and endpoint
        will be obtained from local environment variables
        via a :class:`azuretools.auth.EnvCredentialHandler`.
    container_image_version
        Version of the container to use. Default 'latest'.
    excluded_locations
        List of two letter USPS location abbreviations to
        exclude from the job. Defaults to locations for which
        we typically do not have available NSSP ED visit data:
        ``["AS", "GU", "MO", "MP", "PR", "UM", "VI", "WY"]``.

    Returns
    -------
    None
    """
    # None sentinel instead of a mutable list default argument.
    if excluded_locations is None:
        excluded_locations = ["AS", "GU", "MO", "MP", "PR", "UM", "VI", "WY"]

    supported_diseases = ["COVID-19", "Influenza"]

    disease_list = diseases.split()
    invalid_diseases = set(disease_list) - set(supported_diseases)
    if invalid_diseases:
        raise ValueError(
            f"Unsupported diseases: {', '.join(invalid_diseases)}; "
            f"supported diseases are: {', '.join(supported_diseases)}"
        )

    # Credentials come from local environment variables.
    creds = EnvCredentialHandler()
    client = get_batch_service_client(creds)
    job = models.JobAddParameter(
        id=job_id,
        pool_info=models.PoolInformation(pool_id=pool_id),
    )
    create_job_if_not_exists(client, job, verbose=True)

    container_image = (
        f"{creds.azure_container_registry_account}."
        f"{creds.azure_container_registry_domain}/"
        f"{container_image_name}:{container_image_version}"
    )
    # Mount the blob storage containers the forecast script reads/writes.
    container_settings = get_container_settings(
        container_image,
        working_directory="containerImageDefault",
        mount_pairs=[
            {
                "source": "nssp-etl",
                "target": "/pyrenew-hew/nssp_demo/nssp-etl",
            },
            {
                "source": "nssp-archival-vintages",
                "target": "/pyrenew-hew/nssp_demo/nssp-archival-vintages",
            },
            {
                "source": "prod-param-estimates",
                "target": "/pyrenew-hew/nssp_demo/params",
            },
            {
                "source": "pyrenew-test-output",
                "target": "/pyrenew-hew/nssp_demo/private_data",
            },
        ],
    )

    # Command template; per-task values are filled in via str.format below.
    base_call = (
        "/bin/bash -c '"
        "python nssp_demo/forecast_state.py "
        "--disease {disease} "
        "--state {state} "
        "--n-training-days 365 "
        "--n-warmup 1000 "
        "--n-samples 500 "
        "--facility-level-nssp-data-dir nssp_demo/nssp-etl/gold "
        "--state-level-nssp-data-dir "
        "nssp_demo/nssp-archival-vintages/gold "
        "--param-data-dir nssp_demo/params "
        "--output-data-dir nssp_demo/private_data "
        "--report-date {report_date:%Y-%m-%d} "
        "--exclude-last-n-days 2 "
        "--score "
        "--eval-data-path "
        "nssp_demo/nssp-archival-vintages/latest_comprehensive.parquet"
        "'"
    )

    # Census reference table of USPS state/territory abbreviations.
    locations = pl.read_csv(
        "https://www2.census.gov/geo/docs/reference/state.txt", separator="|"
    )

    all_locations = (
        locations.filter(~pl.col("STUSAB").is_in(excluded_locations))
        .get_column("STUSAB")
        .to_list()
    )

    # Weekly report dates starting 2023-10-11, 30 weeks total.
    report_dates = [
        datetime.date(2023, 10, 11) + datetime.timedelta(weeks=x)
        for x in range(30)
    ]

    # One Batch task per (disease, report date, location) combination.
    for disease, report_date, loc in itertools.product(
        disease_list, report_dates, all_locations
    ):
        task = get_task_config(
            f"{job_id}-{loc}-{disease}-{report_date}",
            base_call=base_call.format(
                state=loc,
                disease=disease,
                report_date=report_date,
            ),
            container_settings=container_settings,
        )
        client.task.add(job_id, task)

    return None


# Command-line interface: three required positionals plus optional
# container/exclusion overrides, forwarded to main() as keyword args.
parser = argparse.ArgumentParser()
parser.add_argument(
    "job_id", type=str, help="Name for the Azure batch job"
)
parser.add_argument(
    "pool_id",
    type=str,
    help=("Name of the Azure batch pool on which to run the job"),
)
parser.add_argument(
    "diseases",
    type=str,
    help=(
        "Name(s) of disease(s) to run as part of the job, "
        "as a whitespace-separated string. Supported "
        "values are 'COVID-19' and 'Influenza'."
    ),
)
parser.add_argument(
    "--container-image-name",
    type=str,
    default="pyrenew-hew",
    help="Name of the container to use for the job.",
)
parser.add_argument(
    "--container-image-version",
    type=str,
    default="latest",
    help="Version of the container to use for the job.",
)
parser.add_argument(
    "--excluded-locations",
    type=str,
    default="AS GU MO MP PR UM VI WY",
    help=(
        "Two-letter USPS location abbreviations to "
        "exclude from the job, as a whitespace-separated "
        "string. Defaults to a set of locations for which "
        "we typically do not have available NSSP ED visit "
        "data: 'AS GU MO MP PR UM VI WY'."
    ),
)


if __name__ == "__main__":
    cli_args = vars(parser.parse_args())
    # main() expects a list of abbreviations, not the raw string.
    cli_args["excluded_locations"] = cli_args["excluded_locations"].split()
    main(**cli_args)
Loading

0 comments on commit b905057

Please sign in to comment.