Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stan and PyRenew model comparison doc #9

Merged
merged 40 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
d4e1984
add files from pyrenew PR
damonbayer Aug 22, 2024
6c4d0d9
change dirichlet prior
damonbayer Aug 23, 2024
c14db7f
Merge branch 'main' into hosp_only_ww_model
damonbayer Aug 26, 2024
34801cb
update linkes to ww-inference-model
damonbayer Aug 26, 2024
f68cae1
fix imports
damonbayer Aug 26, 2024
da067b9
fix function names and dirichlet variable import
damonbayer Aug 26, 2024
bccae36
actually use dirichlet distribution
damonbayer Aug 26, 2024
768d67b
fix stan data loading
damonbayer Aug 26, 2024
5f8b39e
quarto deps
damonbayer Aug 27, 2024
bb7685e
rework model for updated ar process
damonbayer Aug 27, 2024
6185d4a
let quarto demo render
damonbayer Aug 27, 2024
103ab10
make pre-commit happy
damonbayer Aug 27, 2024
ccbe486
recommitting broken model with fewer changes
damonbayer Aug 27, 2024
78ee4d3
add polars dependency
damonbayer Aug 27, 2024
5bf396b
fixed model
damonbayer Aug 27, 2024
b5bb785
do forecasting
damonbayer Aug 27, 2024
00139a2
use DifferencedProcess directly
damonbayer Aug 28, 2024
29d2721
Use compute_delay_ascertained_incidence
damonbayer Aug 28, 2024
1200281
add create_hosp_only_ww_model_from_stan_data
damonbayer Aug 28, 2024
57b7ee5
use built in transformations
damonbayer Aug 29, 2024
0f3c875
add plotting module
damonbayer Aug 29, 2024
119a4c6
use new functions in notebook
damonbayer Aug 29, 2024
4b6758b
cleanup comment
damonbayer Aug 29, 2024
1bacce0
refactor for predictive plotting
damonbayer Aug 29, 2024
c2b0512
clean up some comments
damonbayer Aug 29, 2024
ee86c53
Delete notebooks/hosp_only_ww_model.md
damonbayer Aug 29, 2024
830376b
clean up imports
damonbayer Aug 29, 2024
2f246a1
clean up posterior plots
damonbayer Aug 29, 2024
0f2436a
save data
damonbayer Aug 30, 2024
d5b941f
checkin
damonbayer Sep 3, 2024
516120c
Try reduced sigma_rt_prior
damonbayer Sep 3, 2024
32986c7
update gitignore
damonbayer Sep 4, 2024
4911aa2
Create model_comp
damonbayer Sep 4, 2024
48e75d2
Remove model_comp.R
damonbayer Sep 4, 2024
4f56fb5
add note about IHR length
damonbayer Sep 4, 2024
0767206
Add intro note
damonbayer Sep 4, 2024
e72427f
Merge branch 'main' into dmb_demo_model_comp
damonbayer Sep 4, 2024
58a0d82
Revert changes in notebooks/data/fit/stan_data.json
damonbayer Sep 5, 2024
a05435a
update renv lockfile
damonbayer Sep 5, 2024
3ec45c1
Add note about running other notebooks first
damonbayer Sep 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,7 @@ docs/site/

.DS_Store
poetry.lock


notebooks/*_files/
notebooks/*.md
2 changes: 1 addition & 1 deletion notebooks/data/fit/stan_data.json
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@
"ww_site_mod_sd_sd": 0.25,
"inf_feedback_prior_logmean": 6.37408,
"inf_feedback_prior_logsd": 0.4,
"sigma_rt_prior": 0.1,
"sigma_rt_prior": 0.0001,
"log_phi_g_prior_mean": -2.302585,
"log_phi_g_prior_sd": 5,
"ww_sampled_sites": [1, 3, 4, 3, 2, 3, 1, 1, 2, 1, 1, 4, 1, 3, 1, 3, 1, 3, 4, 1, 2, 3, 1, 1, 1, 1, 4, 1, 3, 2, 1, 4, 2, 1, 1, 2, 3, 1, 2, 2, 3, 3, 1, 2, 1, 2, 1, 4, 1, 1, 2, 2, 3, 1, 4, 1, 1, 3, 4, 3, 4, 4, 1, 3, 1, 4, 1, 4, 2, 1, 1, 1, 1, 1, 3, 1, 1, 1, 3, 3, 3, 2, 3, 1, 2, 2, 1, 2, 1, 3, 3, 4, 1, 3, 4, 1, 3, 4],
Expand Down
2 changes: 1 addition & 1 deletion notebooks/data/fit_hosp_only/stan_data.json
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@
"ww_site_mod_sd_sd": 0.25,
"inf_feedback_prior_logmean": 6.37408,
"inf_feedback_prior_logsd": 0.4,
"sigma_rt_prior": 0.1,
"sigma_rt_prior": 0.0001,
"log_phi_g_prior_mean": -2.302585,
"log_phi_g_prior_sd": 5,
"ww_sampled_sites": [1, 3, 4, 3, 2, 3, 1, 1, 2, 1, 1, 4, 1, 3, 1, 3, 1, 3, 4, 1, 2, 3, 1, 1, 1, 1, 4, 1, 3, 2, 1, 4, 2, 1, 1, 2, 3, 1, 2, 2, 3, 3, 1, 2, 1, 2, 1, 4, 1, 1, 2, 2, 3, 1, 4, 1, 1, 3, 4, 3, 4, 4, 1, 3, 1, 4, 1, 4, 2, 1, 1, 1, 1, 1, 3, 1, 1, 1, 3, 3, 3, 2, 3, 1, 2, 2, 1, 2, 1, 3, 3, 4, 1, 3, 4, 1, 3, 4],
Expand Down
7 changes: 5 additions & 2 deletions notebooks/hosp_only_ww_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ my_hosp_only_ww_model, data_observed_hospital_admissions = (
We check that we can simulate from the prior predictive
```{python}
# | label: prior predictive
n_forecast_days = 28
n_forecast_days = 35

prior_predictive = my_hosp_only_ww_model.prior_predictive(
n_datapoints=len(data_observed_hospital_admissions) + n_forecast_days,
numpyro_predictive_args={"num_samples": 200},
Expand Down Expand Up @@ -117,5 +118,7 @@ for key in list(idata.posterior.keys()):
## Save for Post-Processing

```{python}
idata.to_dataframe().to_csv("data/fit_hosp_only/pyrenew_inference_data.csv")
idata.to_dataframe().to_csv(
"data/fit_hosp_only/pyrenew_inference_data.csv", index=False
)
```
158 changes: 158 additions & 0 deletions notebooks/model_comp.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
---
title: "PyRenew and wwinference Fit and Forecast Comparison"
format: gfm
editor: visual
---

This document shows graphical comparisons for key variables in the PyRenew model fit to example data (notebooks/hosp_only_ww_model.qmd) and Stan model notebooks/wwinference.Rmd (notebooks/wwinference.Rmd).

```{r}
#| output: false
library(tidyverse)
library(tidybayes)
library(fs)
library(cmdstanr)
library(posterior)
library(jsonlite)
library(scales)
library(here)
ci_width <- c(0.5, 0.8, 0.95)
```

## Load Data

```{r}
hosp_data <- tibble(.value = here(path("notebooks",
"data", "fit_hosp_only", "stan_data",
ext = "json"
)) |>
jsonlite::read_json() |>
pluck("hosp") |>
unlist()) |>
mutate(time = row_number())

stan_files <-
dir_ls(here(path("notebooks", "data", "fit_hosp_only")), glob = "*wwinference*") |>
enframe(name = NULL, value = "file_path") |>
mutate(file_details = path_ext_remove(path_file(file_path))) |>
separate_wider_delim(file_details,
delim = "-",
names = c("model", "date", "chain", "hash")
) |>
mutate(date = ymd_hm(date)) |>
filter(date == max(date)) |>
pull(file_path)


stan_tidy_draws <- read_cmdstan_csv(stan_files)$post_warmup_draws |> tidy_draws()

pyrenew_tidy_draws <- read_csv(here(path("notebooks",
"data", "fit_hosp_only", "pyrenew_inference_data",
ext = "csv"
))) |>
rename_with(\(varname) str_remove_all(varname, "\\(|\\)|\\'|(, \\d+)")) |>
rename(
.chain = chain,
.iteration = draw
) |>
mutate(across(c(.chain, .iteration), \(x) as.integer(x + 1))) |>
mutate(
.draw = tidybayes:::draw_from_chain_and_iteration_(.chain, .iteration),
.after = .iteration
) |>
pivot_longer(-starts_with("."),
names_sep = ", ",
names_to = c("distribution", "name")
) |>
{
\(x) split(x |> select(-distribution), f = as.factor(x$distribution))
}() |>
map(\(x) pivot_wider(x, names_from = name) |> tidy_draws())
```

## Calculate Credible Intervals for Plotting

```{r}
combined_ci_for_plotting <-
bind_rows(
pyrenew_tidy_draws$posterior_predictive |>
gather_draws(observed_hospital_admissions[time], rt[time], ihr[time]) |>
median_qi(.width = ci_width) |>
mutate(model = "pyrenew"),
stan_tidy_draws |>
gather_draws(pred_hosp[time], rt[time], p_hosp[time]) |>
mutate(.variable = case_when(
.variable == "pred_hosp" ~ "observed_hospital_admissions",
.variable == "p_hosp" ~ "ihr",
TRUE ~ .variable
)) |>
median_qi(.width = ci_width) |>
mutate(model = "stan")
)
```



## Hospital Admission Comparison

```{r}
combined_ci_for_plotting |>
filter(.variable == "observed_hospital_admissions") |>
ggplot(aes(time, .value)) +
facet_wrap(~model) +
geom_lineribbon(aes(ymin = .lower, ymax = .upper), color = "#08519c") +
scale_fill_brewer(
name = "Credible Interval Width",
labels = ~ percent(as.numeric(.))
) +
geom_point(data = hosp_data) +
cowplot::theme_cowplot() +
ggtitle("Vignette Data Model Comparison") +
scale_y_continuous("Hospital Admissions") +
scale_x_continuous("Time") +
theme(legend.position = "bottom")
```



## Rt Comparions

```{r}
combined_ci_for_plotting |>
filter(.variable == "rt") |>
ggplot(aes(time, .value)) +
facet_wrap(~model) +
geom_lineribbon(aes(ymin = .lower, ymax = .upper), color = "#08519c") +
scale_fill_brewer(
name = "Credible Interval Width",
labels = ~ percent(as.numeric(.))
) +
cowplot::theme_cowplot() +
ggtitle("Vignette Data Model Comparison") +
scale_y_log10("Rt", breaks = scales::log_breaks(n = 6)) +
scale_x_continuous("Time") +
theme(legend.position = "bottom") +
geom_hline(yintercept = 1, linetype = "dashed")
```


## IHR Comparison

```{r}
combined_ci_for_plotting |>
filter(.variable == "ihr") |>
ggplot(aes(time, .value)) +
facet_wrap(~model) +
geom_lineribbon(aes(ymin = .lower, ymax = .upper), color = "#08519c") +
scale_fill_brewer(
name = "Credible Interval Width",
labels = ~ percent(as.numeric(.))
) +
cowplot::theme_cowplot() +
ggtitle("Vignette Data Model Comparison") +
scale_y_log10("IHR (p_hosp)", breaks = scales::log_breaks(n = 6)) +
scale_x_continuous("Time") +
theme(legend.position = "bottom")
```

IHR lengths are different (Stan model generates an unnecessarily long version, see https://github.com/CDCgov/ww-inference-model/issues/43#issuecomment-2330269879)
1 change: 1 addition & 0 deletions notebooks/wwinference.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ params <- get_params(
package = "wwinference"
)
)
params$sigma_rt_prior <- 0.0001
```

## Wastewater data pre-processing
Expand Down
Loading