Function To Cast InferenceData Into tidy_draws Format #36

Open · AFg6K7h4fhy2 wants to merge 76 commits into main

Commits (76)
3865001
initial commit for this PR; begin skeleton experimentation file
AFg6K7h4fhy2 Oct 28, 2024
763355e
some unfinished experimentation code; priority status change high to …
AFg6K7h4fhy2 Oct 28, 2024
44e7fe2
add first semi-failed attempt at converting entire idata object to ti…
AFg6K7h4fhy2 Oct 30, 2024
31c7b72
add attempt at option 2
AFg6K7h4fhy2 Oct 30, 2024
9a87902
slightly modify spread draws example
AFg6K7h4fhy2 Nov 4, 2024
c632ae8
more minor changes to tidy draws notebook
AFg6K7h4fhy2 Nov 4, 2024
123ad51
light edits during DHM convo
AFg6K7h4fhy2 Nov 7, 2024
a3c2d17
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Nov 25, 2024
f44a6ee
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Nov 25, 2024
cb883e3
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Nov 25, 2024
df922d4
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Nov 26, 2024
21968be
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Dec 5, 2024
7dcd7d3
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Dec 9, 2024
718ba85
a DB conversion attempt
AFg6K7h4fhy2 Dec 12, 2024
7394d4d
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Dec 16, 2024
4a77d50
begin references file; create external program folder
AFg6K7h4fhy2 Dec 16, 2024
4c18634
add csv ignoring when converting to csv
AFg6K7h4fhy2 Dec 16, 2024
9be91b4
further edits; problem with dates
AFg6K7h4fhy2 Dec 16, 2024
0884c3d
partial attempt edits
AFg6K7h4fhy2 Dec 16, 2024
0dd9616
minor reconsideration of unpivot pathways
AFg6K7h4fhy2 Dec 16, 2024
2a0f2b9
minor update, tidydraws replicator
AFg6K7h4fhy2 Dec 18, 2024
c746e3b
minor cleaning of existing notebook code
AFg6K7h4fhy2 Dec 18, 2024
e4e8e0f
some of the reimplementation attempt
AFg6K7h4fhy2 Jan 16, 2025
c5ed832
testing of conversion between date intervals and dataframe
AFg6K7h4fhy2 Jan 17, 2025
e9e2aab
small further edit
AFg6K7h4fhy2 Jan 17, 2025
9a27187
add pyrenew inference file; some code mods
AFg6K7h4fhy2 Jan 17, 2025
cc20975
remove extraneous notebooks
AFg6K7h4fhy2 Jan 21, 2025
4a8a0ae
remove more idata to tidy code
AFg6K7h4fhy2 Jan 21, 2025
0c4301a
create tidy draws dataframe dict
AFg6K7h4fhy2 Jan 21, 2025
814f68c
error comment
AFg6K7h4fhy2 Jan 21, 2025
beba8f3
remove earlier code version
AFg6K7h4fhy2 Jan 22, 2025
b12ffc2
begin port of notebook into codebase
AFg6K7h4fhy2 Jan 22, 2025
d85f49c
update docstring
AFg6K7h4fhy2 Jan 22, 2025
85c0dbf
add some of the notebook
AFg6K7h4fhy2 Jan 22, 2025
786d2e4
some minor notebook edits
AFg6K7h4fhy2 Jan 22, 2025
4c4e6dd
further refine notebook
AFg6K7h4fhy2 Jan 22, 2025
e949796
update func name for clarity
AFg6K7h4fhy2 Jan 28, 2025
ab0beb6
remove historical scripts
AFg6K7h4fhy2 Jan 28, 2025
27e7ada
tidybayes API
AFg6K7h4fhy2 Jan 28, 2025
b06f6e6
unfinished edits addressing several comments
AFg6K7h4fhy2 Jan 28, 2025
42acd05
docstring change
AFg6K7h4fhy2 Jan 28, 2025
2ed4c29
more corrective edits
AFg6K7h4fhy2 Jan 28, 2025
0ad7c24
another corrective edit
AFg6K7h4fhy2 Jan 28, 2025
9452acd
check for invalid groups
AFg6K7h4fhy2 Jan 29, 2025
325aac4
some extranous code removal
AFg6K7h4fhy2 Jan 30, 2025
462d57b
add partial, in-need-of-edits tests
AFg6K7h4fhy2 Jan 30, 2025
dae049f
comment test to prevent fail; add nesting of groups
AFg6K7h4fhy2 Jan 30, 2025
5e4cd63
comment test to prevent fail; add nesting of groups
AFg6K7h4fhy2 Jan 30, 2025
5c28831
small update to showcase script
AFg6K7h4fhy2 Jan 30, 2025
d246a44
uncomment failing tests; run pre-commit
AFg6K7h4fhy2 Feb 3, 2025
496c65c
attempt fix pre-commit issues
AFg6K7h4fhy2 Feb 3, 2025
2c34af7
pre-commit edits
AFg6K7h4fhy2 Feb 3, 2025
dba6e7a
more prudent usage of ruff; pre-commit fixes
AFg6K7h4fhy2 Feb 4, 2025
e1cdbb9
revert version edits from black error
AFg6K7h4fhy2 Feb 4, 2025
9884aa9
use selectors; finally, nice to figure that out
AFg6K7h4fhy2 Feb 4, 2025
90f5484
fix some conflicts with daily to epiweekly
AFg6K7h4fhy2 Feb 4, 2025
683a7bf
Merge branch 'main' into 18-function-to-cast-inferencedata-into-tidy_…
AFg6K7h4fhy2 Feb 4, 2025
3547ab6
fix pre-commit errors
AFg6K7h4fhy2 Feb 4, 2025
2999ec3
add pivot to make life easier; not sure if to aggregate by first
AFg6K7h4fhy2 Feb 4, 2025
04d6a50
further debugging edits
AFg6K7h4fhy2 Feb 4, 2025
cdbe464
draws and iterations debugged
AFg6K7h4fhy2 Feb 4, 2025
987c273
revert versioning in test yaml
AFg6K7h4fhy2 Feb 4, 2025
8361c36
update location table; add united states data; update descriptions in…
AFg6K7h4fhy2 Feb 5, 2025
41a8136
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Feb 5, 2025
bb55bf4
remove extraneous united states parquet call
AFg6K7h4fhy2 Feb 5, 2025
1cdef9d
Update forecasttools/idata_to_tidy.py
AFg6K7h4fhy2 Feb 6, 2025
519c048
Update forecasttools/idata_to_tidy.py
AFg6K7h4fhy2 Feb 6, 2025
1019ac0
fix docstring; fix chain equation
AFg6K7h4fhy2 Feb 6, 2025
92c34bd
revert ensure listlike import
AFg6K7h4fhy2 Feb 6, 2025
b3f6367
remove tab ignoral
AFg6K7h4fhy2 Feb 6, 2025
84bb99e
switch from melt to pivot
AFg6K7h4fhy2 Feb 6, 2025
1a94da4
lightweight change to dev deps
AFg6K7h4fhy2 Feb 10, 2025
7a737f6
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Feb 10, 2025
740f9c8
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Feb 11, 2025
b3544e5
commented change; was examining results
AFg6K7h4fhy2 Feb 14, 2025
3c9609c
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 Feb 14, 2025
Binary file added assets/external/inference_data_1.nc
Binary file not shown.
17 changes: 17 additions & 0 deletions assets/misc/references.bib
@@ -0,0 +1,17 @@

# jax
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}

# numpyro
@article{phan2019composable,
title={Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro},
author={Phan, Du and Pradhan, Neeraj and Jankowiak, Martin},
journal={arXiv preprint arXiv:1912.11554},
year={2019}
}
2 changes: 2 additions & 0 deletions forecasttools/__init__.py
@@ -4,6 +4,7 @@
import polars as pl

from forecasttools.daily_to_epiweekly import df_aggregate_to_epiweekly
from forecasttools.idata_to_tidy import convert_inference_data_to_tidydraws
from forecasttools.idata_w_dates_to_df import (
add_time_coords_to_idata_dimension,
add_time_coords_to_idata_dimensions,
@@ -99,4 +100,5 @@
"generate_time_range_for_dim",
"validate_iter_has_expected_types",
"ensure_listlike",
"convert_inference_data_to_tidydraws",
]
4 changes: 2 additions & 2 deletions forecasttools/data.py
@@ -293,7 +293,7 @@ def make_nshn_fitting_dataset(
)
.sort(["state", "date"])
)
-    df_covid.write_csv(file_save_path)
+    df_covid.write_parquet(file_save_path)
if dataset == "flu":
df_flu = (
df.select(
@@ -306,7 +306,7 @@
.rename({"previous_day_admission_influenza_confirmed": "hosp"})
.sort(["state", "date"])
)
-    df_flu.write_csv(file_save_path)
+    df_flu.write_parquet(file_save_path)
print(f"The file {file_save_path} has been created.")


95 changes: 95 additions & 0 deletions forecasttools/idata_to_tidy.py
@@ -0,0 +1,95 @@
"""
Contains functions for interfacing between
the tidyverse and arviz, including the
conversion of idata objects (and hence
their groups) into tidy-usable objects.
"""

import arviz as az
import polars as pl
import polars.selectors as cs


def convert_inference_data_to_tidydraws(
idata: az.InferenceData, groups: list[str] | None = None
) -> dict[str, pl.DataFrame]:
"""
Creates a dictionary of polars dataframes
from the groups of an arviz InferenceData
object for use with the tidybayes API.

Parameters
----------
idata : az.InferenceData
An InferenceData object.
groups : list[str] | None
A list of groups to transform to
tidy draws format. If None, defaults
to all groups in the InferenceData.

Returns
-------
dict[str, pl.DataFrame]
A dictionary of groups from the idata
for use with the tidybayes API.
"""
available_groups = list(idata.groups())
if groups is None:
groups = available_groups
else:
invalid_groups = [
group for group in groups if group not in available_groups
]
if invalid_groups:
raise ValueError(
f"Requested groups {invalid_groups} not found"
" in this InferenceData object."
f" Available groups: {available_groups}"
)

idata_df = pl.DataFrame(idata.to_dataframe())

tidy_dfs = {
group: (
idata_df.select("chain", "draw", cs.starts_with(f"('{group}',"))
.rename(
{
col: col.split(", ")[1].strip("')")
for col in idata_df.columns
if col.startswith(f"('{group}',")
}
)
# draw in arviz is iteration in tidybayes
.rename({"draw": ".iteration", "chain": ".chain"})
.unpivot(
index=[".chain", ".iteration"],
variable_name="variable",
value_name="value",
)
.with_columns(
pl.col("variable").str.replace(r"\[.*\]", "").alias("variable")
)
.with_columns(pl.col(".iteration") + 1, pl.col(".chain") + 1)
.with_columns(
(pl.col(".iteration").n_unique()).alias("draws_per_chain"),
)
.with_columns(
(
((pl.col(".chain") - 1) * pl.col("draws_per_chain"))
+ pl.col(".iteration")
).alias(".draw")
)
AFg6K7h4fhy2 marked this conversation as resolved.

@dylanhmorris (Collaborator) commented on Feb 4, 2025:

> Need to drop "draws_per_chain", but also it's not a given that all chains will have the same number of draws. Instead, it is more robust to compute this as .iteration + <n_draws_in_all_previous_chains>. There are many ways to do that in polars.
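A minimal sketch of one such approach in polars (illustrative only, not this PR's code; assumes the 1-indexed `.chain` and `.iteration` columns produced above):

```python
import polars as pl


def add_draw_column(df: pl.DataFrame) -> pl.DataFrame:
    # offset each chain's .iteration by the number of draws in all
    # previous chains, so unequal chain lengths still yield a unique .draw
    offsets = (
        df.group_by(".chain")
        .agg(pl.col(".iteration").n_unique().alias("n_draws"))
        .sort(".chain")
        .with_columns(
            (pl.col("n_draws").cum_sum() - pl.col("n_draws")).alias("offset")
        )
        .select(".chain", "offset")
    )
    return (
        df.join(offsets, on=".chain")
        .with_columns((pl.col(".iteration") + pl.col("offset")).alias(".draw"))
        .drop("offset")
    )
```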

# .with_columns(
# pl.arange(1, pl.count() + 1).alias(".draw")
# )
.pivot(
values="value",
index=[".chain", ".iteration", ".draw"],
on="variable",
aggregate_function="first",
)
.sort([".chain", ".iteration", ".draw"])
)
for group in groups
}
return tidy_dfs
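For orientation, a minimal usage sketch of the new function (not part of the diff; the file path and group mirror the notebook added later in this PR):

```python
import arviz as az

import forecasttools

# load the canonical pyrenew-hew InferenceData shipped with this PR
idata = az.from_netcdf("assets/external/inference_data_1.nc")
tidy = forecasttools.convert_inference_data_to_tidydraws(
    idata, groups=["posterior_predictive"]
)
# one polars DataFrame per requested group, keyed by .chain/.iteration/.draw
print(tidy["posterior_predictive"].head())
```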
Binary file added forecasttools/united_states.parquet
Binary file not shown.
160 changes: 160 additions & 0 deletions notebooks/pyrenew_dates_to_tidy.qmd
@@ -0,0 +1,160 @@
---
title: Add Dates To Pyrenew Idata And Use In Tidy-Verse
format: gfm
engine: jupyter
---

_The following notebook illustrates the addition of dates to an external `idata` object, then demonstrates that the converted output is usable with the tidyverse in R._

__Load In Packages And External Pyrenew InferenceData Object__

```{python}
#| echo: true

import forecasttools
import arviz as az
import xarray as xr
import os
import polars as pl
import subprocess
import tempfile
from datetime import date, timedelta

xr.set_options(display_expand_data=False, display_expand_attrs=False)


pyrenew_idata_path = "../assets/external/inference_data_1.nc"
pyrenew_idata = az.from_netcdf(pyrenew_idata_path)
pyrenew_idata
```

__Define Groups To Save And Convert__


```{python}
#| echo: true

pyrenew_groups = ["posterior_predictive"]
tidy_usable_groups = forecasttools.convert_inference_data_to_tidydraws(
idata=pyrenew_idata,
groups=pyrenew_groups
)

# show output
tidy_usable_groups
```

```{python}
# TO DELETE
nested_tidy_dfs = {
group: {
var: df.select([".chain", ".iteration", ".draw", var])
for var in [col for col in df.columns if col not in [".chain", ".iteration", ".draw"]]
}
for group, df in tidy_usable_groups.items()
}

nested_tidy_dfs
```

__Demonstrate Adding Time To Pyrenew InferenceData__

```{python}
#| echo: true

start_time_as_dt = date(2022, 8, 1) # arbitrary

pyrenew_target_var = pyrenew_idata["posterior_predictive"]["observed_hospital_admissions"]
print(pyrenew_target_var)

pyrenew_var_w_dates = forecasttools.generate_time_range_for_dim(
start_time_as_dt=start_time_as_dt,
variable_data=pyrenew_target_var,
dimension="observed_hospital_admissions_dim_0",
time_step=timedelta(days=1),
)
print(pyrenew_var_w_dates[:5], type(pyrenew_var_w_dates[0]))
```

__Add Dates To Pyrenew InferenceData__

```{python}
#| echo: true

pyrenew_idata_w_dates = forecasttools.add_time_coords_to_idata_dimension(
idata=pyrenew_idata,
group="posterior_predictive",
variable="observed_hospital_admissions",
dimension="observed_hospital_admissions_dim_0",
start_date_iso=start_time_as_dt,
time_step=timedelta(days=1),
)

print(pyrenew_idata_w_dates["posterior_predictive"]["observed_hospital_admissions"]["observed_hospital_admissions_dim_0"])
pyrenew_idata_w_dates
```

__Again Convert The Dated Pyrenew InferenceData To Tidy-Usable__


```{python}

pyrenew_groups = ["posterior_predictive"]
tidy_usable_groups_w_dates = forecasttools.convert_inference_data_to_tidydraws(
idata=pyrenew_idata_w_dates,
groups=pyrenew_groups
)
tidy_usable_groups_w_dates
```

__Examine The Dataframe In The Tidyverse__

```{python}
def light_r_runner(r_code: str) -> None:
"""
Run R code from Python as a temp file.
"""
with tempfile.NamedTemporaryFile(suffix=".R", delete=False) as temp_r_file:
temp_r_file.write(r_code.encode("utf-8"))
temp_r_file_path = temp_r_file.name
try:
subprocess.run(["Rscript", temp_r_file_path], check=True)
except subprocess.CalledProcessError as e:
print(f"R script failed with error: {e}")
finally:
os.remove(temp_r_file_path)

#for key, tidy_df in tidy_usable_groups_w_dates.items():
# file_name = f"{key}.csv"
# if not os.path.exists(file_name):
# tidy_df.write_csv(file_name)
# print(f"Saved {file_name}")

for group, var in nested_tidy_dfs.items():
for var_name, tidy_data in var.items():
file_name = f"{var_name}.csv"
#if not os.path.exists(file_name):
tidy_data.write_csv(file_name)
print(f"Saved {file_name}")


r_code_to_verify_tibble = """
library(magrittr)
library(tidyverse)
library(tidybayes)

csv_files <- c("rt.csv")

for (csv_file in csv_files) {
tibble_data <- read_csv(csv_file)

print(paste("Tibble from", csv_file))
print(tibble_data)

tidy_data <- tibble_data %>%
tidybayes::tidy_draws()
print(tidy_data)
}
"""
light_r_runner(r_code_to_verify_tibble)
```
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -31,6 +31,7 @@ polars = "^1.8.2"
xarray = "^2024.9.0"
matplotlib = "^3.9.2"
epiweeks = "^2.3.0"
metaflow = "^2.13.9"
numpyro = "^0.17.0"
numpy = "^2.2.2"
tqdm = "^4.67.1"
@@ -44,6 +45,8 @@ patsy = "^0.5.6"
nbformat = "^5.10.4"
nbclient = "^0.10.0"
jupyter = "^1.1.1"
pandas = "^2.2.3"
metaflow = "^2.13.9"
A collaborator commented on `metaflow = "^2.13.9"`:

> Do you want this?

jupyter-cache = "^1.0.1"


Expand Down
53 changes: 53 additions & 0 deletions tests/test_idata_to_tidy.py
@@ -0,0 +1,53 @@
import arviz as az
import numpy as np
import polars as pl
import pytest
import xarray as xr

import forecasttools


@pytest.fixture
def mock_inference_data():
np.random.seed(42)
A collaborator commented:

> Running np.random.seed changes global state. Better practice to do something like this: https://builtin.com/data-science/numpy-random-seed
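For example, a local `Generator` keeps the seeding out of global state (a sketch, not the PR's code):

```python
import numpy as np

# local Generator; NumPy's global RNG state is untouched
rng = np.random.default_rng(42)
samples = rng.standard_normal((2, 100))
```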

posterior_predictive = xr.Dataset(
    {
        "observed_hospital_admissions": (
            ("chain", "draw"),
            np.random.randn(2, 100),
        ),
    },
    coords={"chain": [0, 1], "draw": np.arange(100)},
)
A collaborator commented on lines +13 to +18:

> Why not run the test on the provided inference_data_1.nc? Or are you planning to remove it?

AFg6K7h4fhy2 (Collaborator, author) replied:

> I am on the fence about removing it. It seems good to have a canonical pyrenew-hew .nc file on hand, especially given that forecasttools-py does (and will, even more so) interface abundantly with pyrenew models. On the other hand, having adequate and general idata/xarray representations seems good for testing too. I do not know whether the latter must exist at the cost of the former. I lean towards having both, with the .nc file perhaps being used in notebooks and the "fake" idatas being used for testing.
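If both are kept, a companion fixture could load the shipped file directly (a sketch; the repo-relative path is an assumption):

```python
@pytest.fixture
def pyrenew_inference_data():
    # hypothetical fixture for the canonical pyrenew-hew file added in this PR
    return az.from_netcdf("assets/external/inference_data_1.nc")
```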


idata = az.InferenceData(posterior_predictive=posterior_predictive)

return idata


def test_valid_conversion(mock_inference_data):
result = forecasttools.convert_inference_data_to_tidydraws(
mock_inference_data, ["posterior_predictive"]
)
assert isinstance(result, dict)
assert "posterior_predictive" in result
assert isinstance(result["posterior_predictive"], pl.DataFrame)

df = result["posterior_predictive"]
assert all(
    col in df.columns
    for col in [".chain", ".iteration", ".draw", "observed_hospital_admissions"]
)

A collaborator commented:

> Would be good to check that individual values are as expected, not just that the draws are unique, with one per row.
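A sketch of such a value-level check (hypothetical test, not part of this PR; assumes the fixture above and 1-indexed `.chain`/`.iteration` in the output):

```python
def test_values_match_source(mock_inference_data):
    df = forecasttools.convert_inference_data_to_tidydraws(
        mock_inference_data, ["posterior_predictive"]
    )["posterior_predictive"]
    source = mock_inference_data.posterior_predictive[
        "observed_hospital_admissions"
    ].values
    # chain 1, iteration 1 should hold the source array's [0, 0] entry
    first = df.filter(
        (pl.col(".chain") == 1) & (pl.col(".iteration") == 1)
    )
    assert first["observed_hospital_admissions"][0] == pytest.approx(source[0, 0])
```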

assert df[".draw"].n_unique() == df[".draw"].shape[0]


def test_invalid_group(mock_inference_data):
with pytest.raises(ValueError, match="not found in this InferenceData object"):
forecasttools.convert_inference_data_to_tidydraws(
mock_inference_data, ["invalid_group"]
)


def test_empty_group_list(mock_inference_data):
result = forecasttools.convert_inference_data_to_tidydraws(
mock_inference_data, []
)
assert result == {}