generated from CDCgov/template
-
Notifications
You must be signed in to change notification settings - Fork 2
Function To Cast InferenceData Into tidy_draws
Format
#36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
AFg6K7h4fhy2
wants to merge
88
commits into
main
Choose a base branch
from
18-function-to-cast-inferencedata-into-tidy_draws-format
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
88 commits
Select commit
Hold shift + click to select a range
3865001
initial commit for this PR; begin skeleton experimentation file
AFg6K7h4fhy2 763355e
some unfinished experimentation code; priority status change high to …
AFg6K7h4fhy2 44e7fe2
add first semi-failed attempt at converting entire idata object to ti…
AFg6K7h4fhy2 31c7b72
add attempt at option 2
AFg6K7h4fhy2 9a87902
slightly modify spread draws example
AFg6K7h4fhy2 c632ae8
more minor changes to tidy draws notebook
AFg6K7h4fhy2 123ad51
light edits during DHM convo
AFg6K7h4fhy2 a3c2d17
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 f44a6ee
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 cb883e3
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 df922d4
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 21968be
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 7dcd7d3
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 718ba85
a DB conversion attempt
AFg6K7h4fhy2 7394d4d
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 4a77d50
begin references file; create external program folder
AFg6K7h4fhy2 4c18634
add csv ignoring when converting to csv
AFg6K7h4fhy2 9be91b4
further edits; problem with dates
AFg6K7h4fhy2 0884c3d
partial attempt edits
AFg6K7h4fhy2 0dd9616
minor reconsideration of unpivot pathways
AFg6K7h4fhy2 2a0f2b9
minor update, tidydraws replicator
AFg6K7h4fhy2 c746e3b
minor cleaning of existing notebook code
AFg6K7h4fhy2 e4e8e0f
some of the reimplementation attempt
AFg6K7h4fhy2 c5ed832
testing of conversion between date intervals and dataframe
AFg6K7h4fhy2 e9e2aab
small further edit
AFg6K7h4fhy2 9a27187
add pyrenew inference file; some code mods
AFg6K7h4fhy2 cc20975
remove extraneous notebooks
AFg6K7h4fhy2 4a8a0ae
remove more idata to tidy code
AFg6K7h4fhy2 0c4301a
create tidy draws dataframe dict
AFg6K7h4fhy2 814f68c
error comment
AFg6K7h4fhy2 beba8f3
remove earlier code version
AFg6K7h4fhy2 b12ffc2
begin port of notebook into codebase
AFg6K7h4fhy2 d85f49c
update docstring
AFg6K7h4fhy2 85c0dbf
add some of the notebook
AFg6K7h4fhy2 786d2e4
some minor notebook edits
AFg6K7h4fhy2 4c4e6dd
further refine notebook
AFg6K7h4fhy2 e949796
update func name for clarity
AFg6K7h4fhy2 ab0beb6
remove historical scripts
AFg6K7h4fhy2 27e7ada
tidybayes API
AFg6K7h4fhy2 b06f6e6
unfinished edits addressing several comments
AFg6K7h4fhy2 42acd05
docstring change
AFg6K7h4fhy2 2ed4c29
more corrective edits
AFg6K7h4fhy2 0ad7c24
another corrective edit
AFg6K7h4fhy2 9452acd
check for invalid groups
AFg6K7h4fhy2 325aac4
some extranous code removal
AFg6K7h4fhy2 462d57b
add partial, in-need-of-edits tests
AFg6K7h4fhy2 dae049f
comment test to prevent fail; add nesting of groups
AFg6K7h4fhy2 5e4cd63
comment test to prevent fail; add nesting of groups
AFg6K7h4fhy2 5c28831
small update to showcase script
AFg6K7h4fhy2 d246a44
uncomment failing tests; run pre-commit
AFg6K7h4fhy2 496c65c
attempt fix pre-commit issues
AFg6K7h4fhy2 2c34af7
pre-commit edits
AFg6K7h4fhy2 dba6e7a
more prudent usage of ruff; pre-commit fixes
AFg6K7h4fhy2 e1cdbb9
revert version edits from black error
AFg6K7h4fhy2 9884aa9
use selectors; finally, nice to figure that out
AFg6K7h4fhy2 90f5484
fix some conflicts with daily to epiweekly
AFg6K7h4fhy2 683a7bf
Merge branch 'main' into 18-function-to-cast-inferencedata-into-tidy_…
AFg6K7h4fhy2 3547ab6
fix pre-commit errors
AFg6K7h4fhy2 2999ec3
add pivot to make life easier; not sure if to aggregate by first
AFg6K7h4fhy2 04d6a50
further debugging edits
AFg6K7h4fhy2 cdbe464
draws and iterations debugged
AFg6K7h4fhy2 987c273
revert versioning in test yaml
AFg6K7h4fhy2 8361c36
update location table; add united states data; update descriptions in…
AFg6K7h4fhy2 41a8136
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 bb55bf4
remove extraneous united states parquet call
AFg6K7h4fhy2 1cdef9d
Update forecasttools/idata_to_tidy.py
AFg6K7h4fhy2 519c048
Update forecasttools/idata_to_tidy.py
AFg6K7h4fhy2 1019ac0
fix docstring; fix chain equation
AFg6K7h4fhy2 92c34bd
revert ensure listlike import
AFg6K7h4fhy2 b3f6367
remove tab ignoral
AFg6K7h4fhy2 84bb99e
switch from melt to pivot
AFg6K7h4fhy2 1a94da4
lightweight change to dev deps
AFg6K7h4fhy2 7a737f6
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 740f9c8
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 b3544e5
commented change; was examining results
AFg6K7h4fhy2 3c9609c
Merge remote-tracking branch 'origin/main' into 18-function-to-cast-i…
AFg6K7h4fhy2 ca7c340
have draw calculation take into account row count
AFg6K7h4fhy2 9b2cdc9
remove extraneous metaflow
AFg6K7h4fhy2 84ec6e1
some of the tests using simple idata class
AFg6K7h4fhy2 07e83a9
additional test
AFg6K7h4fhy2 238c80c
Update forecasttools/idata_to_tidy.py
AFg6K7h4fhy2 27436cd
revert to original aggregate function argument value
AFg6K7h4fhy2 70219ad
add base posterior predictive test for pyrenew idata
AFg6K7h4fhy2 60e0b5a
change test path; capture aggregate function error
AFg6K7h4fhy2 61849e0
change col search method from strip to re
AFg6K7h4fhy2 ba2847e
add lambda rename rather than dictionary comprehension
AFg6K7h4fhy2 15d920b
remove comment
AFg6K7h4fhy2 0604de1
update pre-commit config file to remove loop binding linting at reque…
AFg6K7h4fhy2 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
|
||
# jax | ||
@software{jax2018github, | ||
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, | ||
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, | ||
url = {http://github.com/jax-ml/jax}, | ||
version = {0.3.13}, | ||
year = {2018}, | ||
} | ||
|
||
# numpyro | ||
@article{phan2019composable, | ||
title={Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro}, | ||
author={Phan, Du and Pradhan, Neeraj and Jankowiak, Martin}, | ||
journal={arXiv preprint arXiv:1912.11554}, | ||
year={2019} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
""" | ||
Contains functions for interfacing between | ||
the tidy-verse and arviz, which includes | ||
the conversion of idata objects (and hence | ||
their groups) in tidy-usable objects. | ||
""" | ||
|
||
import re | ||
|
||
import arviz as az | ||
import polars as pl | ||
import polars.selectors as cs | ||
|
||
|
||
def convert_inference_data_to_tidydraws( | ||
idata: az.InferenceData, groups: list[str] = None | ||
) -> dict[str, pl.DataFrame]: | ||
""" | ||
Creates a dictionary of polars dataframes | ||
from the groups of an arviz InferenceData | ||
object for use with the tidybayes API. | ||
|
||
Parameters | ||
---------- | ||
idata : az.InferenceData | ||
An InferenceData object. | ||
groups : list[str] | ||
A list of groups to transform to | ||
tidy draws format. Defaults to all | ||
groups in the InferenceData. | ||
|
||
Returns | ||
------- | ||
dict[str, pl.DataFrame] | ||
A dictionary of groups from the idata | ||
for use with the tidybayes API. | ||
""" | ||
available_groups = list(idata.groups()) | ||
if groups is None: | ||
groups = available_groups | ||
else: | ||
invalid_groups = [ | ||
group for group in groups if group not in available_groups | ||
] | ||
if invalid_groups: | ||
raise ValueError( | ||
f"Requested groups {invalid_groups} not found" | ||
" in this InferenceData object." | ||
f" Available groups: {available_groups}" | ||
) | ||
|
||
idata_df = pl.DataFrame(idata.to_dataframe()) | ||
|
||
tidy_dfs = { | ||
group: ( | ||
idata_df.select("chain", "draw", cs.starts_with(f"('{group}',")) | ||
.rename( | ||
lambda col, group=group: re.search( | ||
r",\s*'?(.+?)'?\)", col | ||
).group(1) | ||
if col.startswith(f"('{group}',") | ||
else col | ||
) | ||
# draw in arviz is iteration in tidybayes | ||
.rename({"draw": ".iteration", "chain": ".chain"}) | ||
.unpivot( | ||
index=[".chain", ".iteration"], | ||
variable_name="variable", | ||
value_name="value", | ||
) | ||
.with_columns( | ||
pl.col("variable").str.replace(r"\[.*\]", "").alias("variable") | ||
) | ||
.with_columns(pl.col(".iteration") + 1, pl.col(".chain") + 1) | ||
.pivot( | ||
values="value", | ||
index=[".chain", ".iteration"], | ||
columns="variable", | ||
aggregate_function=None, | ||
) | ||
.sort([".chain", ".iteration"]) | ||
.with_row_index(name=".draw", offset=1) | ||
) | ||
for group in groups | ||
} | ||
return tidy_dfs |
Binary file not shown.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
--- | ||
title: Add Dates To Pyrenew Idata And Use In Tidy-Verse | ||
format: gfm | ||
engine: jupyter | ||
--- | ||
|
||
_The following notebook illustrates the addition of dates to an external `idata` object before demonstrating the tidy-usable capabilities in R._ | ||
|
||
__Load In Packages And External Pyrenew InferenceData Object__ | ||
|
||
```{python} | ||
#| echo: true | ||
|
||
import forecasttools | ||
import arviz as az | ||
import xarray as xr | ||
import os | ||
import polars as pl | ||
import subprocess | ||
import tempfile | ||
from datetime import date, timedelta | ||
|
||
xr.set_options(display_expand_data=False, display_expand_attrs=False) | ||
|
||
|
||
pyrenew_idata_path = "../assets/external/inference_data_1.nc" | ||
pyrenew_idata = az.from_netcdf(pyrenew_idata_path) | ||
pyrenew_idata | ||
``` | ||
|
||
__Define Groups To Save And Convert__ | ||
|
||
|
||
```{python} | ||
#| echo: true | ||
|
||
pyrenew_groups = ["posterior_predictive"] | ||
tidy_usable_groups = forecasttools.convert_inference_data_to_tidydraws( | ||
idata=pyrenew_idata, | ||
groups=pyrenew_groups | ||
) | ||
|
||
# show output | ||
tidy_usable_groups | ||
``` | ||
|
||
```{python} | ||
# TO DELETE | ||
nested_tidy_dfs = { | ||
group: { | ||
var: df.select([".chain", ".iteration", ".draw", var]) | ||
for var in [col for col in df.columns if col not in [".chain", ".iteration", ".draw"]] | ||
} | ||
for group, df in tidy_usable_groups.items() | ||
} | ||
|
||
nested_tidy_dfs | ||
``` | ||
|
||
__Demonstrate Adding Time To Pyrenew InferenceData__ | ||
|
||
```{python} | ||
#| echo: true | ||
|
||
start_time_as_dt = date(2022, 8, 1) # arbitrary | ||
|
||
pyrenew_target_var = pyrenew_idata["posterior_predictive"]["observed_hospital_admissions"] | ||
print(pyrenew_target_var) | ||
|
||
pyrenew_var_w_dates = forecasttools.generate_time_range_for_dim( | ||
start_time_as_dt=start_time_as_dt, | ||
variable_data=pyrenew_target_var, | ||
dimension="observed_hospital_admissions_dim_0", | ||
time_step=timedelta(days=1), | ||
) | ||
print(pyrenew_var_w_dates[:5], type(pyrenew_var_w_dates[0])) | ||
``` | ||
|
||
__Add Dates To Pyrenew InferenceData__ | ||
|
||
```{python} | ||
#| echo: true | ||
|
||
pyrenew_idata_w_dates = forecasttools.add_time_coords_to_idata_dimension( | ||
idata=pyrenew_idata, | ||
group="posterior_predictive", | ||
variable="observed_hospital_admissions", | ||
dimension="observed_hospital_admissions_dim_0", | ||
start_date_iso=start_time_as_dt, | ||
time_step=timedelta(days=1), | ||
) | ||
|
||
print(pyrenew_idata_w_dates["posterior_predictive"]["observed_hospital_admissions"]["observed_hospital_admissions_dim_0"]) | ||
pyrenew_idata_w_dates | ||
``` | ||
|
||
__Again Convert The Dated Pyrenew InferenceData To Tidy-Usable__ | ||
|
||
|
||
```{python} | ||
|
||
pyrenew_groups = ["posterior_predictive"] | ||
tidy_usable_groups_w_dates = forecasttools.convert_inference_data_to_tidydraws( | ||
idata=pyrenew_idata_w_dates, | ||
groups=pyrenew_groups | ||
) | ||
tidy_usable_groups_w_dates | ||
``` | ||
|
||
__Examine The Dataframe In The Tidyverse__ | ||
|
||
```{python} | ||
def light_r_runner(r_code: str) -> None: | ||
""" | ||
Run R code from Python as a temp file. | ||
""" | ||
with tempfile.NamedTemporaryFile(suffix=".R", delete=False) as temp_r_file: | ||
temp_r_file.write(r_code.encode("utf-8")) | ||
temp_r_file_path = temp_r_file.name | ||
try: | ||
subprocess.run(["Rscript", temp_r_file_path], check=True) | ||
except subprocess.CalledProcessError as e: | ||
print(f"R script failed with error: {e}") | ||
finally: | ||
os.remove(temp_r_file_path) | ||
|
||
#for key, tidy_df in tidy_usable_groups_w_dates.items(): | ||
# file_name = f"{key}.csv" | ||
# if not os.path.exists(file_name): | ||
# tidy_df.write_csv(file_name) | ||
# print(f"Saved {file_name}") | ||
|
||
for group, var in nested_tidy_dfs.items(): | ||
for var_name, tidy_data in var.items(): | ||
file_name = f"{var_name}.csv" | ||
#if not os.path.exists(file_name): | ||
tidy_data.write_csv(file_name) | ||
print(f"Saved {file_name}") | ||
|
||
|
||
r_code_to_verify_tibble = """ | ||
library(magrittr) | ||
library(tidyverse) | ||
library(tidybayes) | ||
|
||
csv_files <- c("rt.csv") | ||
|
||
for (csv_file in csv_files) { | ||
tibble_data <- read_csv(csv_file) | ||
|
||
print(paste("Tibble from", csv_file)) | ||
print(tibble_data) | ||
|
||
tidy_data <- tibble_data %>% | ||
tidybayes::tidy_draws() | ||
print(tidy_data) | ||
} | ||
""" | ||
light_r_runner(r_code_to_verify_tibble) | ||
``` |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.