Skip to content

Commit

Permalink
Use population sizes from forecasttools.location_table now that the…
Browse files Browse the repository at this point in the history
…y are provided (#324)
  • Loading branch information
dylanhmorris authored Feb 7, 2025
1 parent 318a2d4 commit 3a0a5bf
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 21 deletions.
2 changes: 1 addition & 1 deletion pipelines/hubverse_create_observed_data_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def save_observed_data_tables(
first_training_date=datetime.date(2023, 1, 1),
state_pop_df=state_pop,
),
["US"] + [x for x in state_pop["abb"]],
state_pop.get_column("abb").to_list(),
)
).filter(pl.col("disease") == disease)
for disease in ["COVID-19", "Influenza", "RSV", "Total"]
Expand Down
35 changes: 16 additions & 19 deletions pipelines/prep_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,15 @@ def process_state_level_data(
disease_key = _disease_map.get(disease, disease)

if state_abb == "US":
locations_to_aggregate = (
state_pop_df.filter(pl.col("abb") != "US")
.get_column("abb")
.unique()
)
logger.info("Aggregating state-level data to national")
state_level_nssp_data = aggregate_to_national(
state_level_nssp_data,
state_pop_df["abb"].unique(),
locations_to_aggregate,
first_training_date,
national_geo_value="US",
)
Expand Down Expand Up @@ -205,9 +210,14 @@ def aggregate_facility_level_nssp_to_state(

if state_abb == "US":
logger.info("Aggregating facility-level data to national")
locations_to_aggregate = (
state_pop_df.filter(pl.col("abb") != "US")
.get_column("abb")
.unique()
)
facility_level_nssp_data = aggregate_to_national(
facility_level_nssp_data,
state_pop_df["abb"].unique(),
locations_to_aggregate,
first_training_date,
national_geo_value="US",
)
Expand Down Expand Up @@ -247,20 +257,12 @@ def verify_no_date_gaps(df: pl.DataFrame):


def get_state_pop_df():
census_dat = pl.read_csv(
"https://raw.githubusercontent.com/k5cents/usa/master/data-raw/facts.csv"
).rename({"name": "long_name"})

state_pop_df = forecasttools.location_table.join(
census_dat, on="long_name", how="right"
).select(
return forecasttools.location_table.select(
pl.col("short_name").alias("abb"),
pl.col("long_name").alias("name"),
pl.col("population"),
)

return state_pop_df


def get_pmfs(param_estimates: pl.LazyFrame, state_abb: str, disease: str):
generation_interval_pmf = (
Expand Down Expand Up @@ -327,14 +329,9 @@ def process_and_save_state(

state_pop_df = get_state_pop_df()

if state_abb == "US":
state_pop = state_pop_df["population"].sum()
else:
state_pop = (
state_pop_df.filter(pl.col("abb") == state_abb)
.get_column("population")
.to_list()[0]
)
state_pop = state_pop_df.filter(pl.col("abb") == state_abb).item(
0, "population"
)

(generation_interval_pmf, delay_pmf, right_truncation_pmf) = get_pmfs(
param_estimates=param_estimates, state_abb=state_abb, disease=disease
Expand Down
13 changes: 13 additions & 0 deletions pipelines/tests/test_prep_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from pipelines import prep_data


def test_get_state_pop_df():
    """
    Check that get_state_pop_df() yields a table
    with the expected row count and exactly
    the expected set of column names.
    """
    # 50 states, 7 other jurisdictions, US national
    expected_n_rows = 58
    expected_columns = {"name", "abb", "population"}

    state_pop = prep_data.get_state_pop_df()

    assert state_pop.height == expected_n_rows
    assert set(state_pop.columns) == expected_columns
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pypdf = "^5.1.0"
pyarrow = "^18.0.0"
pygit2 = "^1.17.0"
azuretools = {git = "https://github.com/cdcgov/cfa-azuretools"}
forecasttools = {git = "https://github.com/CDCgov/forecasttools-py"}
forecasttools = {git = "https://github.com/cdcgov/forecasttools-py"}
tomli-w = "^1.1.0"

[tool.poetry.group.test]
Expand Down

0 comments on commit 3a0a5bf

Please sign in to comment.