From d8900146c67a69d9b1d5e9237c664974d3e20ac0 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 23 Jan 2025 17:15:32 +0000 Subject: [PATCH] Make right truncation offset in forecast data None rather than 0 (#305) --- pyproject.toml | 2 +- pyrenew_hew/pyrenew_hew_data.py | 3 +- ...ss.py => test_latent_infection_process.py} | 0 tests/test_pyrenew_hew_data.py | 100 ++++++++++++++++++ 4 files changed, 103 insertions(+), 2 deletions(-) rename tests/{test_LatentInfectionProcess.py => test_latent_infection_process.py} (100%) create mode 100644 tests/test_pyrenew_hew_data.py diff --git a/pyproject.toml b/pyproject.toml index c705a0de..7d058eb5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ readme = "README.md" [tool.poetry.dependencies] python = "^3.12" -pyrenew = {git = "https://github.com/CDCgov/PyRenew"} +pyrenew = {git = "https://github.com/cdcgov/pyrenew"} ipywidgets = "^8.1.5" arviz = "^0.20.0" pyyaml = "^6.0.2" diff --git a/pyrenew_hew/pyrenew_hew_data.py b/pyrenew_hew/pyrenew_hew_data.py index 75ac5d1d..4c437d83 100644 --- a/pyrenew_hew/pyrenew_hew_data.py +++ b/pyrenew_hew/pyrenew_hew_data.py @@ -169,5 +169,6 @@ def to_forecast_data(self, n_forecast_points: int) -> Self: first_ed_visits_date=self.first_data_date_overall, first_hospital_admissions_date=(self.first_data_date_overall), first_wastewater_date=self.first_data_date_overall, - right_truncation_offset=0, + right_truncation_offset=None, + # by default, want forecasts of complete reports ) diff --git a/tests/test_LatentInfectionProcess.py b/tests/test_latent_infection_process.py similarity index 100% rename from tests/test_LatentInfectionProcess.py rename to tests/test_latent_infection_process.py diff --git a/tests/test_pyrenew_hew_data.py b/tests/test_pyrenew_hew_data.py new file mode 100644 index 00000000..90ebbfac --- /dev/null +++ b/tests/test_pyrenew_hew_data.py @@ -0,0 +1,100 @@ +from datetime import datetime + +import pytest + +from pyrenew_hew.pyrenew_hew_data import PyrenewHEWData + + +@pytest.mark.parametrize( + [ + "n_ed_visits_datapoints", + "n_hospital_admissions_datapoints", + "n_wastewater_datapoints", + "right_truncation_offset", + "first_ed_visits_date", + "first_hospital_admissions_date", + "first_wastewater_date", + "n_forecast_points", + ], + [ + [ + 50, + 0, + 0, + 5, + datetime(2023, 1, 1), + datetime(2022, 2, 5), + datetime(2025, 12, 5), + 10, + ], + [ + 0, + 325, + 2, + 5, + datetime(2025, 1, 1), + datetime(2023, 5, 25), + datetime(2022, 4, 5), + 10, + ], + [ + 0, + 0, + 2, + 3, + datetime(2025, 1, 1), + datetime(2025, 2, 5), + datetime(2024, 12, 5), + 30, + ], + [ + 0, + 0, + 23, + 3, + datetime(2025, 1, 1), + datetime(2025, 2, 5), + datetime(2024, 12, 5), + 30, + ], + ], +) +def test_to_forecast_data( + n_ed_visits_datapoints: int, + n_hospital_admissions_datapoints: int, + n_wastewater_datapoints: int, + right_truncation_offset: int, + first_ed_visits_date: datetime.date, + first_hospital_admissions_date: datetime.date, + first_wastewater_date: datetime.date, + n_forecast_points: int, +) -> None: + """ + Test the to_forecast_data method + """ + data = PyrenewHEWData( + n_ed_visits_datapoints=n_ed_visits_datapoints, + n_hospital_admissions_datapoints=n_hospital_admissions_datapoints, + n_wastewater_datapoints=n_wastewater_datapoints, + first_ed_visits_date=first_ed_visits_date, + first_hospital_admissions_date=first_hospital_admissions_date, + first_wastewater_date=first_wastewater_date, + right_truncation_offset=right_truncation_offset, + ) + + assert data.right_truncation_offset == right_truncation_offset + assert data.right_truncation_offset is not None + + forecast_data = data.to_forecast_data(n_forecast_points) + n_days_expected = data.n_days_post_init + n_forecast_points + n_weeks_expected = n_days_expected // 7 + assert forecast_data.n_ed_visits_datapoints == n_days_expected + assert forecast_data.n_wastewater_datapoints == n_days_expected + assert forecast_data.n_hospital_admissions_datapoints == n_weeks_expected + assert forecast_data.right_truncation_offset is None + assert forecast_data.first_ed_visits_date == data.first_data_date_overall + assert ( + forecast_data.first_hospital_admissions_date + == data.first_data_date_overall + ) + assert forecast_data.first_wastewater_date == data.first_data_date_overall