Skip to content

Commit

Permalink
Add functionality to generate a tsv of observed data across all locat…
Browse files Browse the repository at this point in the history
…ions, including percentages (#74)
  • Loading branch information
dylanhmorris authored Oct 25, 2024
1 parent c8b0d41 commit 21575cf
Showing 1 changed file with 146 additions and 0 deletions.
146 changes: 146 additions & 0 deletions nssp_demo/pull_state_timeseries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import argparse
import datetime
import logging
from pathlib import Path

import polars as pl


def main(
nssp_data_dir,
output_path,
report_date: str | datetime.date,
first_date_to_pull: str | datetime.date = None,
separator="\t",
diseases=["covid", "influenza", "rsv"],
):
diseases_to_column_names = dict(
covid="COVID-19/Omicron",
influenza="Influenza",
rsv="RSV",
total="Total",
)

diseases_to_pull = [
diseases_to_column_names.get(disease) for disease in diseases
]

col_names_to_pull = diseases_to_pull + ["Total"]

if isinstance(report_date, str):
if report_date == "latest":
report_date = max(
f.stem for f in Path(nssp_data_dir).glob("*.parquet")
)
report_date = datetime.datetime.strptime(
report_date, "%Y-%m-%d"
).date()
elif not isinstance(report_date, datetime.date):
raise ValueError(
"`report_date` must be either be a "
"a `datetime.date` object, or a string "
"giving a date in IS08601 format."
)

if first_date_to_pull is None:
first_date_to_pull = pl.col("reference_date").min()
elif isinstance(first_date_to_pull, str):
first_date_to_pull = datetime.datetime.strptime(
first_date_to_pull, "%Y-%m-%d"
).date()
elif not isinstance(first_date_to_pull, datetime.date):
raise ValueError(
"`first_date_to_pull` must be `None` "
"in which case all available dates are pulled, ",
"a `datetime.date` object, or a string "
"giving a date in IS08601 format.",
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

logger.info(f"Report date: {report_date}")

datafile = f"{report_date}.parquet"
nssp_data = pl.scan_parquet(Path(nssp_data_dir, datafile))

data = (
nssp_data.filter(
pl.col("disease").is_in(col_names_to_pull),
pl.col("metric") == "count_ed_visits",
pl.col("reference_date") > first_date_to_pull,
pl.col("report_date") == report_date,
)
.select(["reference_date", "geo_value", "disease", "value"])
.group_by(["reference_date", "geo_value", "disease"])
.agg(value=pl.col("value").sum())
.sort(["reference_date", "geo_value"])
.collect()
.pivot(on="disease", index=["reference_date", "geo_value"])
.rename(
{
v: f"count_{k}"
for k, v in diseases_to_column_names.items()
if v in col_names_to_pull
}
)
.with_columns(
**{
f"frac_{x}": (pl.col(f"count_{x}") / pl.col("count_total"))
for x in diseases
}
)
.with_columns(
**{f"pct_{x}": (100.0 * pl.col(f"frac_{x}")) for x in diseases}
)
.select(
[
pl.col("reference_date").alias("date"),
pl.col("geo_value").alias("location"),
]
+ [
item
for x in diseases
for item in [f"count_{x}", f"frac_{x}", f"pct_{x}"]
]
+ ["count_total"]
)
)

print(data)

logger.info(f"Saving data to {output_path}.")

data.write_csv(file=output_path, separator=separator)

logger.info("Data preparation complete.")


parser = argparse.ArgumentParser(
description="Pull NSSP data across pathogens."
)
parser.add_argument(
"nssp_data_dir",
type=Path,
help=(
"Directory in which to look for NSSP data gold table "
".parquet files."
),
)
parser.add_argument(
"output_path",
type=Path,
help="Path to which to save the output file, as a tsv.",
)

parser.add_argument(
"--report-date",
type=str,
default="latest",
help="Report date in YYYY-MM-DD format or latest (default: latest)",
)


if __name__ == "__main__":
args = parser.parse_args()
main(**vars(args))

0 comments on commit 21575cf

Please sign in to comment.