diff --git a/pipelines/hubverse_create_observed_data_tables.py b/pipelines/hubverse_create_observed_data_tables.py index 660e613f..873a5028 100644 --- a/pipelines/hubverse_create_observed_data_tables.py +++ b/pipelines/hubverse_create_observed_data_tables.py @@ -44,7 +44,7 @@ def save_observed_data_tables( first_training_date=datetime.date(2023, 1, 1), state_pop_df=state_pop, ), - ["US"] + [x for x in state_pop["abb"]], + state_pop.get_column("abb").to_list(), ) ).filter(pl.col("disease") == disease) for disease in ["COVID-19", "Influenza", "RSV", "Total"] diff --git a/pipelines/prep_data.py b/pipelines/prep_data.py index 455b6d60..a18c49cc 100644 --- a/pipelines/prep_data.py +++ b/pipelines/prep_data.py @@ -147,10 +147,15 @@ def process_state_level_data( disease_key = _disease_map.get(disease, disease) if state_abb == "US": + locations_to_aggregate = ( + state_pop_df.filter(pl.col("abb") != "US") + .get_column("abb") + .unique() + ) logger.info("Aggregating state-level data to national") state_level_nssp_data = aggregate_to_national( state_level_nssp_data, - state_pop_df["abb"].unique(), + locations_to_aggregate, first_training_date, national_geo_value="US", ) @@ -205,9 +210,14 @@ def aggregate_facility_level_nssp_to_state( if state_abb == "US": logger.info("Aggregating facility-level data to national") + locations_to_aggregate = ( + state_pop_df.filter(pl.col("abb") != "US") + .get_column("abb") + .unique() + ) facility_level_nssp_data = aggregate_to_national( facility_level_nssp_data, - state_pop_df["abb"].unique(), + locations_to_aggregate, first_training_date, national_geo_value="US", ) @@ -247,20 +257,12 @@ def verify_no_date_gaps(df: pl.DataFrame): def get_state_pop_df(): - census_dat = pl.read_csv( - "https://raw.githubusercontent.com/k5cents/usa/master/data-raw/facts.csv" - ).rename({"name": "long_name"}) - - state_pop_df = forecasttools.location_table.join( - census_dat, on="long_name", how="right" - ).select( + return forecasttools.location_table.select( pl.col("short_name").alias("abb"), pl.col("long_name").alias("name"), pl.col("population"), ) - return state_pop_df - def get_pmfs(param_estimates: pl.LazyFrame, state_abb: str, disease: str): generation_interval_pmf = ( @@ -327,14 +329,9 @@ def process_and_save_state( state_pop_df = get_state_pop_df() - if state_abb == "US": - state_pop = state_pop_df["population"].sum() - else: - state_pop = ( - state_pop_df.filter(pl.col("abb") == state_abb) - .get_column("population") - .to_list()[0] - ) + state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item( + 0, "population" + ) (generation_interval_pmf, delay_pmf, right_truncation_pmf) = get_pmfs( param_estimates=param_estimates, state_abb=state_abb, disease=disease diff --git a/pipelines/tests/test_prep_data.py b/pipelines/tests/test_prep_data.py new file mode 100644 index 00000000..1f8c5585 --- /dev/null +++ b/pipelines/tests/test_prep_data.py @@ -0,0 +1,13 @@ +from pipelines import prep_data + + +def test_get_state_pop_df(): + """ + Confirm get_state_pop_df() + returns a polars data frame + with the expected number of rows + and expected column names + """ + df = prep_data.get_state_pop_df() + assert df.height == 58 # 50 states, 7 other jursidictions, US national + assert set(df.columns) == set(["name", "abb", "population"]) diff --git a/pyproject.toml b/pyproject.toml index 7d058eb5..bdbc5c9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ pypdf = "^5.1.0" pyarrow = "^18.0.0" pygit2 = "^1.17.0" azuretools = {git = "https://github.com/cdcgov/cfa-azuretools"} -forecasttools = {git = "https://github.com/CDCgov/forecasttools-py"} +forecasttools = {git = "https://github.com/cdcgov/forecasttools-py"} tomli-w = "^1.1.0" [tool.poetry.group.test]