Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hosp_only model and demo #4

Merged
merged 30 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
d4e1984
add files from pyrenew PR
damonbayer Aug 22, 2024
6c4d0d9
change dirichlet prior
damonbayer Aug 23, 2024
c14db7f
Merge branch 'main' into hosp_only_ww_model
damonbayer Aug 26, 2024
34801cb
update linkes to ww-inference-model
damonbayer Aug 26, 2024
f68cae1
fix imports
damonbayer Aug 26, 2024
da067b9
fix function names and dirichlet variable import
damonbayer Aug 26, 2024
bccae36
actually use dirichlet distribution
damonbayer Aug 26, 2024
768d67b
fix stan data loading
damonbayer Aug 26, 2024
5f8b39e
quarto deps
damonbayer Aug 27, 2024
bb7685e
rework model for updated ar process
damonbayer Aug 27, 2024
6185d4a
let quarto demo render
damonbayer Aug 27, 2024
103ab10
make pre-commit happy
damonbayer Aug 27, 2024
ccbe486
recommitting broken model with fewer changes
damonbayer Aug 27, 2024
78ee4d3
add polars dependency
damonbayer Aug 27, 2024
5bf396b
fixed model
damonbayer Aug 27, 2024
b5bb785
do forecasting
damonbayer Aug 27, 2024
00139a2
use DifferencedProcess directly
damonbayer Aug 28, 2024
29d2721
Use compute_delay_ascertained_incidence
damonbayer Aug 28, 2024
1200281
add create_hosp_only_ww_model_from_stan_data
damonbayer Aug 28, 2024
57b7ee5
use built in transformations
damonbayer Aug 29, 2024
0f3c875
add plotting module
damonbayer Aug 29, 2024
119a4c6
use new functions in notebook
damonbayer Aug 29, 2024
4b6758b
cleanup comment
damonbayer Aug 29, 2024
1bacce0
refactor for predictive plotting
damonbayer Aug 29, 2024
c2b0512
clean up some comments
damonbayer Aug 29, 2024
ee86c53
Delete notebooks/hosp_only_ww_model.md
damonbayer Aug 29, 2024
830376b
clean up imports
damonbayer Aug 29, 2024
2f246a1
clean up posterior plots
damonbayer Aug 29, 2024
0f2436a
save data
damonbayer Aug 30, 2024
5ada83a
respond to reviewer comments
damonbayer Sep 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions notebooks/hosp_only_ww_model.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
---
title: "Replicating Hospital Only Model from ww-inference-model"
format: gfm
engine: jupyter
---

```{python}
# | label: setup
import jax
import numpyro
import arviz as az
import pyrenew_covid_wastewater.plotting as plotting
from pyrenew_covid_wastewater.hosp_only_ww_model import (
create_hosp_only_ww_model_from_stan_data,
)

numpyro.set_host_device_count(4)
```

## Background

This tutorial provides a demonstration of our reimplementation of "Model 2" from the [`ww-inference-model` project](https://github.com/CDCgov/ww-inference-model).
The model is described [here](https://github.com/CDCgov/ww-inference-model/blob/main/model_definition.md).
Stan code for the model is [here](https://github.com/CDCgov/ww-inference-model/blob/main/inst/stan/wwinference.stan).

The model we provide is designed to be fully-compatible with the stan_data generated in the that project.
We provide the stan data used in the `wwinference` [vignette](https://github.com/CDCgov/ww-inference-model/blob/main/vignettes/wwinference.Rmd) in the [`ww-inference-model` project](https://github.com/CDCgov/ww-inference-model).
The data is available in `notebooks/data/fit_hosp_only/stan_data.json`.
This data was generated by running `notebooks/wwinference.Rmd`, which replicates the original vignette and saves the relevant data.
This script also saves the posterior samples from the model for comparison to our own model.

## Load Data and Create Priors

We begin by loading the Stan data, converting it the correct inputs for our model, and definitng the model.

```{python}
# | label: create model
my_hosp_only_ww_model, data_observed_hospital_admissions = (
create_hosp_only_ww_model_from_stan_data(
"data/fit_hosp_only/stan_data.json"
)
)
```

# Simulate from the model

We check that we can simulate from the prior predictive
```{python}
# | label: prior predictive
n_forecast_days = 28
prior_predictive = my_hosp_only_ww_model.prior_predictive(
n_datapoints=len(data_observed_hospital_admissions) + n_forecast_days,
numpyro_predictive_args={"num_samples": 200},
)
```

# Fit the model

Now we can fit the model to the observed data:
```{python}
# | label: fit the model
my_hosp_only_ww_model.run(
num_warmup=500,
num_samples=500,
rng_key=jax.random.key(200),
data_observed_hospital_admissions=data_observed_hospital_admissions,
mcmc_args=dict(num_chains=4, progress_bar=False),
nuts_args=dict(find_heuristic_step_size=True),
)
```

Create the posterior predictive and forecast:

```{python}
# | label: posterior predictive
posterior_predictive = my_hosp_only_ww_model.posterior_predictive(
n_datapoints=len(data_observed_hospital_admissions) + n_forecast_days
)
```

## Prepare for plotting

```{python}
# | label: prepare for plotting
import arviz as az

idata = az.from_numpyro(
my_hosp_only_ww_model.mcmc,
posterior_predictive=posterior_predictive,
prior=prior_predictive,
)
```

## Plot Predictive Distributions

```{python}
# | label: plot prior preditive
plotting.plot_predictive(idata, prior=True)
```

```{python}
# | label: plot posterior preditive
plotting.plot_predictive(idata)
```

## Plot all posteriors

```{python}
# | label: plot all posteriors
for key in list(idata.posterior.keys()):
try:
plotting.plot_posterior(idata, key)
except Exception as e:
print(f"An error occurred while plotting {key}: {e}")
```

## Save for Post-Processing

```{python}
idata.to_dataframe().to_csv("data/fit_hosp_only/pyrenew_inference_data.csv")
```
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ readme = "README.md"
[tool.poetry.dependencies]
python = "^3.12"
pyrenew = {git = "https://github.com/CDCgov/PyRenew"}
ipywidgets = "^8.1.5"
arviz = "^0.19.0"
pyyaml = "^6.0.2"
jupyter = "^1.0.0"
ipykernel = "^6.29.5"
polars = "^1.5.0"


[build-system]
Expand Down
16 changes: 16 additions & 0 deletions pyrenew-covid-wastewater.Rproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
Version: 1.0

RestoreWorkspace: Default
SaveWorkspace: Default
AlwaysSaveHistory: Default

EnableCodeIndexing: Yes
UseSpacesForTab: Yes
NumSpacesForTab: 2
Encoding: UTF-8

RnwWeave: Sweave
LaTeX: pdfLaTeX

AutoAppendNewline: Yes
StripTrailingWhitespace: Yes
Loading
Loading