Skip to content

Commit

Permalink
Add basic batch postprocessing script
Browse files: browse the repository at this point in the history
  • Loading branch information
dylanhmorris committed Dec 16, 2024
1 parent bd392dd commit 32df62e
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 16 deletions.
6 changes: 4 additions & 2 deletions pipelines/batch/setup_prod_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,14 @@ def main(
help=("Name of the Azure batch pool on which to run the job"),
)
parser.add_argument(
"diseases",
"--diseases",
type=str,
default="COVID-19 Influenza",
help=(
"Name(s) of disease(s) to run as part of the job, "
"as a whitespace-separated string. Supported "
"values are 'COVID-19' and 'Influenza'."
"values are 'COVID-19' and 'Influenza'. "
"Default 'COVID-19 Influenza' (i.e. run for both)."
),
)

Expand Down
42 changes: 31 additions & 11 deletions pipelines/postprocess_all_locations.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import argparse
import logging
import subprocess
from pathlib import Path

Check warning on line 3 in pipelines/postprocess_all_locations.py

View check run for this annotation

Codecov / codecov/patch

pipelines/postprocess_all_locations.py#L1-L3

Added lines #L1 - L3 were not covered by tests

import collate_plots as cp
from utils import get_all_forecast_dirs, parse_model_batch_dir_name

Check warning on line 6 in pipelines/postprocess_all_locations.py

View check run for this annotation

Codecov / codecov/patch

pipelines/postprocess_all_locations.py#L5-L6

Added lines #L5 - L6 were not covered by tests


def create_hubverse_table(model_batch_dir: Path) -> None:
def create_hubverse_table(base_path: Path, model_batch_dir: Path) -> None:
batch_info = parse_model_batch_dir_name(model_batch_dir)

Check warning on line 10 in pipelines/postprocess_all_locations.py

View check run for this annotation

Codecov / codecov/patch

pipelines/postprocess_all_locations.py#L9-L10

Added lines #L9 - L10 were not covered by tests

output_file_name = (

Check warning on line 12 in pipelines/postprocess_all_locations.py

View check run for this annotation

Codecov / codecov/patch

pipelines/postprocess_all_locations.py#L12

Added line #L12 was not covered by tests
Expand All @@ -16,13 +15,14 @@ def create_hubverse_table(model_batch_dir: Path) -> None:
"hubverse-table.tsv"
)

output_path = Path(model_batch_dir, output_file_name)
model_batch_path = Path(base_path, model_batch_dir)
output_path = Path(model_batch_path, output_file_name)

Check warning on line 19 in pipelines/postprocess_all_locations.py

View check run for this annotation

Codecov / codecov/patch

pipelines/postprocess_all_locations.py#L18-L19

Added lines #L18 - L19 were not covered by tests

result = subprocess.run(

Check warning on line 21 in pipelines/postprocess_all_locations.py

View check run for this annotation

Codecov / codecov/patch

pipelines/postprocess_all_locations.py#L21

Added line #L21 was not covered by tests
[
"Rscript",
"pipelines/create_hubverse_table.R",
f"{model_batch_dir}",
f"{model_batch_path}",
f"{output_path}",
],
capture_output=True,
Expand All @@ -32,15 +32,23 @@ def create_hubverse_table(model_batch_dir: Path) -> None:
return None

Check warning on line 32 in pipelines/postprocess_all_locations.py

View check run for this annotation

Codecov / codecov/patch

pipelines/postprocess_all_locations.py#L30-L32

Added lines #L30 - L32 were not covered by tests


def process_model_batch_dir(model_batch_dir: Path) -> None:
cp.process_dir(model_batch_dir)
create_hubverse_table(model_batch_dir)
def process_model_batch_dir(base_dir: Path, model_batch_dir: Path) -> None:
plot_types = ["Disease", "Other", "prop_disease_ed_visits"]
plots_to_collate = [f"{x}_forecast_plot.pdf" for x in plot_types] + [

Check warning on line 37 in pipelines/postprocess_all_locations.py

View check run for this annotation

Codecov / codecov/patch

pipelines/postprocess_all_locations.py#L35-L37

Added lines #L35 - L37 were not covered by tests
f"{x}_forecast_plot_log.pdf" for x in plot_types
]
cp.process_dir(

Check warning on line 40 in pipelines/postprocess_all_locations.py

View check run for this annotation

Codecov / codecov/patch

pipelines/postprocess_all_locations.py#L40

Added line #L40 was not covered by tests
Path(base_dir, model_batch_dir), target_filenames=plots_to_collate
)
create_hubverse_table(base_dir, model_batch_dir)

Check warning on line 43 in pipelines/postprocess_all_locations.py

View check run for this annotation

Codecov / codecov/patch

pipelines/postprocess_all_locations.py#L43

Added line #L43 was not covered by tests


def main(base_forecast_dir: Path):
to_process = get_all_forecast_dirs(base_forecast_dir)
def main(

Check warning on line 46 in pipelines/postprocess_all_locations.py

View check run for this annotation

Codecov / codecov/patch

pipelines/postprocess_all_locations.py#L46

Added line #L46 was not covered by tests
base_forecast_dir: Path, diseases: list[str] = ["COVID-19", "Influenza"]
):
to_process = get_all_forecast_dirs(base_forecast_dir, diseases)
for batch_dir in to_process:
process_model_batch_dir(batch_dir)
process_model_batch_dir(base_forecast_dir, batch_dir)

Check warning on line 51 in pipelines/postprocess_all_locations.py

View check run for this annotation

Codecov / codecov/patch

pipelines/postprocess_all_locations.py#L49-L51

Added lines #L49 - L51 were not covered by tests


if __name__ == "__main__":
Expand All @@ -50,9 +58,21 @@ def main(base_forecast_dir: Path):
parser.add_argument(

Check warning on line 58 in pipelines/postprocess_all_locations.py

View check run for this annotation

Codecov / codecov/patch

pipelines/postprocess_all_locations.py#L58

Added line #L58 was not covered by tests
"base_forecast_dir",
type=Path,
required=True,
help="Directory containing forecast subdirectories.",
)

parser.add_argument(

Check warning on line 64 in pipelines/postprocess_all_locations.py

View check run for this annotation

Codecov / codecov/patch

pipelines/postprocess_all_locations.py#L64

Added line #L64 was not covered by tests
"--diseases",
type=str,
default="COVID-19 Influenza",
help=(
"Name(s) of disease(s) to postprocess, "
"as a whitespace-separated string. Supported "
"values are 'COVID-19' and 'Influenza'. "
"Default 'COVID-19 Influenza' (i.e. postprocess both)."
),
)

args = parser.parse_args()
args.diseases = args.diseases.split()
main(**vars(args))

Check warning on line 78 in pipelines/postprocess_all_locations.py

View check run for this annotation

Codecov / codecov/patch

pipelines/postprocess_all_locations.py#L76-L78

Added lines #L76 - L78 were not covered by tests
6 changes: 3 additions & 3 deletions pipelines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ def parse_model_batch_dir_name(model_batch_dir_name):
)
return dict(
disease=disease_map_lower_[disease],
report_date=datetime.strptime(report_date, "%Y-%m-%d").date(),
first_training_date=datetime.strptime(
report_date=datetime.datetime.strptime(report_date, "%Y-%m-%d").date(),
first_training_date=datetime.datetime.strptime(
first_training_date, "%Y-%m-%d"
).date(),
last_training_date=datetime.strptime(
last_training_date=datetime.datetime.strptime(
last_training_date, "%Y-%m-%d"
).date(),
)
Expand Down

0 comments on commit 32df62e

Please sign in to comment.