Skip to content

Commit 1792e66

Browse files
authored
H,E,W component pipeline run tests (#371)
1 parent b4e6da8 commit 1792e66

File tree

8 files changed

+255
-334
lines changed

8 files changed

+255
-334
lines changed

.github/workflows/pipeline-run-check.yaml

Lines changed: 110 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,39 +9,136 @@ concurrency:
99
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
1010
cancel-in-progress: true
1111

12+
env:
13+
BASE_DIR: "pipelines/end_to_end_test_output"
14+
DATA_DIR: "pipelines/end_to_end_test_output/private_data"
15+
1216
jobs:
13-
run-pipeline:
14-
strategy:
15-
matrix:
16-
os: [ubuntu-22.04, macos-latest]
17-
runs-on: ${{matrix.os}}
17+
generate-data:
18+
runs-on: ubuntu-latest
19+
steps:
20+
- uses: actions/checkout@v4
21+
- name: Set up R
22+
uses: r-lib/actions/setup-r@v2
23+
with:
24+
r-version: "release"
25+
use-public-rspm: true
26+
- name: "Set up dependencies for hewr"
27+
uses: r-lib/actions/setup-r-dependencies@v2
28+
with:
29+
working-directory: hewr
30+
- name: Install hewr
31+
run: pak::local_install("hewr", ask = FALSE)
32+
shell: Rscript {0}
33+
- name: Generate test data
34+
run: Rscript pipelines/generate_test_data.R ${{ env.DATA_DIR }}
35+
36+
- name: Upload artifact
37+
uses: actions/upload-artifact@v4
38+
with:
39+
name: test-data
40+
path: ${{ env.DATA_DIR }}
41+
retention-days: 1
1842

43+
fit-models:
44+
needs: generate-data
45+
runs-on: ubuntu-latest
1946
env:
2047
NHSN_API_KEY_ID: ${{ secrets.NHSN_API_KEY_ID }}
2148
NHSN_API_KEY_SECRET: ${{ secrets.NHSN_API_KEY_SECRET }}
49+
strategy:
50+
matrix:
51+
model: [h, e, he, hw, ew, hew]
52+
disease: [COVID-19, Influenza]
53+
location: [US, CA, MT]
54+
exclude:
55+
- model: hw
56+
disease: Influenza
57+
- model: ew
58+
disease: Influenza
59+
- model: hew
60+
disease: Influenza
61+
- model: hw
62+
location: US
63+
- model: ew
64+
location: US
65+
- model: hew
66+
location: US
2267
steps:
2368
- uses: actions/checkout@v4
2469
- name: Set up python
2570
uses: actions/setup-python@v5
2671
with:
2772
python-version: "3.12"
28-
- name: "Set up R"
73+
- name: Set up R
2974
uses: r-lib/actions/setup-r@v2
3075
with:
3176
r-version: "release"
3277
use-public-rspm: true
33-
- name: "Set up Quarto"
78+
- name: Set up Quarto
3479
uses: quarto-dev/quarto-actions/setup@v2
35-
- name: "Install poetry"
80+
- name: Install poetry
3681
run: pip install poetry
37-
- name: "Install pyrenew-hew"
82+
- name: Install pyrenew-hew
3883
run: poetry install
39-
- name: "Set up dependencies for hewr"
84+
- name: Set up dependencies for hewr
85+
uses: r-lib/actions/setup-r-dependencies@v2
86+
with:
87+
working-directory: hewr
88+
- name: Install hewr
89+
run: pak::local_install("hewr", ask = FALSE)
90+
shell: Rscript {0}
91+
- name: Download test data
92+
uses: actions/download-artifact@v4
93+
with:
94+
name: test-data
95+
path: ${{ env.DATA_DIR }}
96+
- name: Fit model
97+
run: |
98+
poetry run bash pipelines/tests/test_fit.sh ${{ env.BASE_DIR }} \
99+
${{ matrix.disease }} ${{ matrix.location }} ${{ matrix.model }}
100+
- name: Upload artifact
101+
uses: actions/upload-artifact@v4
102+
with:
103+
name: |
104+
test-fit-${{ matrix.disease }}-${{ matrix.location }}-${{ matrix.model }}
105+
path: ${{ env.BASE_DIR }}
106+
107+
postprocess-models:
108+
needs: fit-models
109+
runs-on: ubuntu-latest
110+
steps:
111+
- uses: actions/checkout@v4
112+
- name: Set up python
113+
uses: actions/setup-python@v5
114+
with:
115+
python-version: "3.12"
116+
- name: Set up R
117+
uses: r-lib/actions/setup-r@v2
118+
with:
119+
r-version: "release"
120+
use-public-rspm: true
121+
- name: Set up Quarto
122+
uses: quarto-dev/quarto-actions/setup@v2
123+
- name: Install poetry
124+
run: pip install poetry
125+
- name: Install pyrenew-hew
126+
run: poetry install
127+
- name: Set up dependencies for hewr
40128
uses: r-lib/actions/setup-r-dependencies@v2
41129
with:
42130
working-directory: hewr
43-
- name: "Install hewr"
131+
- name: Install hewr
44132
run: pak::local_install("hewr", ask = FALSE)
45133
shell: Rscript {0}
46-
- name: "Run pipeline"
47-
run: poetry run bash pipelines/tests/test_end_to_end.sh pipelines/tests
134+
- name: Download fitting output
135+
uses: actions/download-artifact@v4
136+
with:
137+
pattern: test-fit-*
138+
path: ${{ env.BASE_DIR }}
139+
merge-multiple: true
140+
- name: Run postprocessing
141+
run: |
142+
poetry run python pipelines/postprocess_forecast_batches.py \
143+
${{ env.DATA_DIR }} \
144+
${{ env.DATA_DIR }}/nssp-etl/latest_comprehensive.parquet

.gitignore

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -400,9 +400,5 @@ private_data/*
400400
*_files/
401401
.vscode/settings.json
402402

403-
# Test data exceptions to the general data exclusion
404-
!pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/data/data.tsv
405-
!pipelines/tests/covid-19_r_2024-01-29_f_2023-11-01_t_2024-01-29/model_runs/TD/data/eval_data.tsv
406-
407-
# Ignore test pipe output
408-
pipelines/tests/private_data/*
403+
# Ignore end to end test output
404+
pipelines/tests/end_to_end_test_output/*

pipelines/forecast_state.py

Lines changed: 34 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
from prep_eval_data import save_eval_data
1515
from pygit2 import Repository
1616

17-
from pyrenew_hew.util import pyrenew_model_name_from_flags
17+
from pyrenew_hew.util import (
18+
flags_from_hew_letters,
19+
pyrenew_model_name_from_flags,
20+
)
1821

1922
numpyro.set_host_device_count(4)
2023

@@ -317,22 +320,20 @@ def main(
317320
facility_level_nssp_data, state_level_nssp_data = None, None
318321

319322
if report_date in available_facility_level_reports:
320-
logger.info(
321-
"Facility level data available for " "the given report date"
322-
)
323+
logger.info("Facility level data available for the given report date")
323324
facility_datafile = f"{report_date}.parquet"
324325
facility_level_nssp_data = pl.scan_parquet(
325326
Path(facility_level_nssp_data_dir, facility_datafile)
326327
)
327328
if state_report_date in available_state_level_reports:
328-
logger.info("State-level data available for the given report " "date.")
329+
logger.info("State-level data available for the given report date.")
329330
state_datafile = f"{state_report_date}.parquet"
330331
state_level_nssp_data = pl.scan_parquet(
331332
Path(state_level_nssp_data_dir, state_datafile)
332333
)
333334
if facility_level_nssp_data is None and state_level_nssp_data is None:
334335
raise ValueError(
335-
"No data available for the requested report date " f"{report_date}"
336+
f"No data available for the requested report date {report_date}"
336337
)
337338

338339
nwss_data_disease_map = {
@@ -516,6 +517,15 @@ def get_available_nwss_reports(
516517
),
517518
)
518519

520+
parser.add_argument(
521+
"--model-letters",
522+
type=str,
523+
help=(
524+
"Fit the model corresponding to the provided model letters (e.g. 'he', 'e', 'hew')."
525+
),
526+
required=True,
527+
)
528+
519529
parser.add_argument(
520530
"--report-date",
521531
type=str,
@@ -528,8 +538,7 @@ def get_available_nwss_reports(
528538
type=Path,
529539
default=Path("private_data", "nssp_etl_gold"),
530540
help=(
531-
"Directory in which to look for facility-level NSSP "
532-
"ED visit data"
541+
"Directory in which to look for facility-level NSSP ED visit data"
533542
),
534543
)
535544

@@ -538,7 +547,7 @@ def get_available_nwss_reports(
538547
type=Path,
539548
default=Path("private_data", "nssp_state_level_gold"),
540549
help=(
541-
"Directory in which to look for state-level NSSP " "ED visit data."
550+
"Directory in which to look for state-level NSSP ED visit data."
542551
),
543552
)
544553

@@ -612,7 +621,7 @@ def get_available_nwss_reports(
612621
type=int,
613622
default=1000,
614623
help=(
615-
"Number of warmup iterations per chain for NUTS" "(default: 1000)."
624+
"Number of warmup iterations per chain for NUTS (default: 1000)."
616625
),
617626
)
618627

@@ -648,45 +657,23 @@ def get_available_nwss_reports(
648657
type=Path,
649658
help=("Path to a parquet file containing compehensive truth data."),
650659
)
651-
652-
parser.add_argument(
653-
"--fit-ed-visits",
654-
type=bool,
655-
action=argparse.BooleanOptionalAction,
656-
help="If provided, fit to ED visit data.",
657-
)
658-
parser.add_argument(
659-
"--fit-hospital-admissions",
660-
type=bool,
661-
action=argparse.BooleanOptionalAction,
662-
help=("If provided, fit to hospital admissions data."),
663-
)
664660
parser.add_argument(
665-
"--fit-wastewater",
666-
type=bool,
667-
action=argparse.BooleanOptionalAction,
668-
help="If provided, fit to wastewater data.",
669-
)
670-
671-
parser.add_argument(
672-
"--forecast-ed-visits",
673-
type=bool,
674-
action=argparse.BooleanOptionalAction,
675-
help="If provided, forecast ED visits.",
676-
)
677-
parser.add_argument(
678-
"--forecast-hospital-admissions",
679-
type=bool,
680-
action=argparse.BooleanOptionalAction,
681-
help=("If provided, forecast hospital admissions."),
682-
)
683-
parser.add_argument(
684-
"--forecast-wastewater",
685-
type=bool,
686-
action=argparse.BooleanOptionalAction,
687-
help="If provided, forecast wastewater concentrations.",
661+
"--additional-forecast-letters",
662+
type=str,
663+
help=(
664+
"Forecast the following signals even if they were not fit. "
665+
"Fit signals are always forecast."
666+
),
667+
default="he",
688668
)
689669

690670
args = parser.parse_args()
691671
numpyro.set_host_device_count(args.n_chains)
692-
main(**vars(args))
672+
fit_flags = flags_from_hew_letters(args.model_letters)
673+
forecast_flags = flags_from_hew_letters(
674+
args.model_letters + args.additional_forecast_letters,
675+
flag_prefix="forecast",
676+
)
677+
delattr(args, "model_letters")
678+
delattr(args, "additional_forecast_letters")
679+
main(**vars(args), **fit_flags, **forecast_flags)

0 commit comments

Comments
 (0)