Add MCMC output numerics to manuscript and appendix, plus traceplots for worst parameters #55
athowes committed Jul 18, 2023
1 parent c439fbc commit f656536
Showing 9 changed files with 255 additions and 99 deletions.
Binary file modified docs/appendix.pdf
226 changes: 155 additions & 71 deletions docs/mcmc-convergence.html

Large diffs are not rendered by default.

Binary file modified docs/paper.pdf
20 changes: 15 additions & 5 deletions src/docs_paper/appendix.Rmd
@@ -550,12 +550,18 @@ ks_summary %>%

# MCMC convergence and suitability

```{r}
mcmc_out <- readRDS("depends/mcmc-out.rds")
```

We assessed MCMC convergence and suitability using a range of graphical and numerical tests.
All potential scale reduction factor $\hat R$ statistics [@vehtari2021rank] were below 1.05 and therefore acceptable (Figure \ref{fig:rhat}).
However, even thinning by a factor of 20, samples were not obtained very efficiently, resulting in the majority of effective sample size (ESS) ratios being below 0.5, with some as low as 0.1 (Figure \ref{fig:ratio}).
As a result, the number of obtained ESS varied substantially by parameter (Figure \ref{fig:ess}).
The largest potential scale reduction factor $\hat R$ was `r signif(mcmc_out$max_rhat, 3)` (Figure \ref{fig:rhat}).
As such, all $\hat R$ values were below the 1.05 threshold.
Even after thinning by a factor of 20, samples were not obtained very efficiently: the majority of effective sample size (ESS) ratios were below 0.5, with some as low as 0.1 (Figure \ref{fig:ratio}).
As a result, the ESS obtained varied substantially by parameter (Figure \ref{fig:ess}): the minimum was `r round(mcmc_out$ess_min, 1)`, the 2.5% quantile was `r round(mcmc_out$ess_lower, 1)`, the median was `r round(mcmc_out$ess_median, 1)`, the 97.5% quantile was `r round(mcmc_out$ess_upper, 1)`, and the maximum was `r round(mcmc_out$ess_max, 1)`.
Traceplots for the parameters with the lowest ESS (`log_sigma_alpha_xs`) and highest $\hat R$ (the 10th index of `ui_lambda_x`) are shown in Figure \ref{fig:worst_trace}.
There were no divergent transitions.
Could add energy plot from @betancourt2017conceptual here.
<!-- Could add energy plot from @betancourt2017conceptual here. -->
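A minimal sketch of that energy plot, assuming `depends/nuts-params.rds` holds the molten data frame returned by `bayesplot::nuts_params()` (as the chunk below suggests):

```r
# Sketch only: np is assumed to be the output of bayesplot::nuts_params()
np <- readRDS("depends/nuts-params.rds")
bayesplot::mcmc_nuts_energy(np)
```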

```{r}
np <- readRDS("depends/nuts-params.rds")
@@ -580,11 +586,15 @@ knitr::include_graphics("depends/ratio.png")
knitr::include_graphics("depends/ess.png")
```

```{r worst_trace, fig.cap="Traceplots of the parameters with the lowest ESS and highest scale reduction factor."}
knitr::include_graphics("depends/worst-trace.png")
```

```{r rho_a, fig.cap="Variation between units can be explained either by high correlation and high variance, or by low correlation and low variance. As such, the AR1 hyperparameters had correlated posteriors. It is this correlation which we made use of with PCA-AGHQ."}
knitr::include_graphics("depends/rho_a.png")
```

```{r rho_a, fig.cap="In contrast to the AR1 hyperparameters, the BYM2 hyperparameter posteriors are much less correlated."}
```{r alpha_x, fig.cap="In contrast to the AR1 hyperparameters, the BYM2 hyperparameter posteriors are much less correlated."}
knitr::include_graphics("depends/alpha_x.png")
```

10 changes: 9 additions & 1 deletion src/docs_paper/figures.R
@@ -3,6 +3,8 @@ library(patchwork)

cols <- c("#56B4E9","#009E73", "#E69F00", "#F0E442", "#0072B2", "#D55E00", "#CC79A7", "#999999")

#' Fig A

TMB::compile("2d.cpp")
dyn.load(TMB::dynlib("2d"))

@@ -106,6 +108,12 @@ ggsave("figA.png", h = 2.5, w = 6.25, bg = "white")

#' Fig B

# aghq <- readRDS("depends/aghq.rds")

#' TODO Generate this figure here: requires having outputs for AGHQ

#' Fig C

df_compare <- readRDS("depends/beta_alpha.rds")

mean <- df_compare %>%
@@ -173,4 +181,4 @@ ecdf_diff <- ggplot(ecdf_df, aes(x = x, y = ecdf_diff, col = method)) +

histogram + ecdf_diff

ggsave("figB.png", h = 4, w = 6.25, background = "white")
ggsave("figC.png", h = 4, w = 6.25, background = "white")
2 changes: 2 additions & 0 deletions src/docs_paper/orderly.yml
@@ -76,9 +76,11 @@ depends:
depends/rhat.png: rhat.png
depends/ratio.png: ratio.png
depends/ess.png: ess.png
depends/mcmc-out.rds: out.rds
depends/rho_a.png: rho_a.png
depends/alpha_x.png: alpha_x.png
depends/ar1-bym2-cor.csv: ar1-bym2-cor.csv
depends/worst-trace.png: worst-trace.png
depends/nuts-params.rds: nuts-params.rds
# - naomi-simple_fit:
# id: latest(parameter:aghq == TRUE && parameter:k == 3 && parameter:s == 8)
30 changes: 19 additions & 11 deletions src/docs_paper/paper.Rmd
@@ -412,7 +412,7 @@ Method & Software & Details \\
\midrule
TMB & \texttt{TMB} & $1000$ samples \\
PCA-AGHQ & \texttt{aghq} & $k = 3, s = 8$ (see Section \ref{sec:pca-use}), $1000$ samples \\
NUTS & \texttt{tmbstan} & $4$ chains of $20000$ iterations, with the first $10000$ iterations of each chain discarded as warmup, thinned by a factor of $20$. Default NUTS tuning parameters \citep{hoffman2014no}. \\
NUTS & \texttt{tmbstan} & $4$ chains of $100,000$ iterations, with the first $50,000$ iterations of each chain discarded as warmup, thinned by a factor of $40$, to give a total of $5000$ samples kept. Default NUTS tuning parameters \citep{hoffman2014no}. \\
\bottomrule
\end{tabularx}
\caption{A summary of settings used for each inferential method.}
@@ -422,11 +422,15 @@ NUTS & \texttt{tmbstan} & $4$ chains of $20000$ iterations, with the first $1000

## NUTS convergence

We increased the chain lengths until all NUTS diagnostics were acceptable.
This required chains of length 20,000, which we thinned by a factor of 20 for ease-of-storage.
All potential scale reduction factors [@gelman1992inference; @vehtari2021rank] were $\hat R < 1.05$ (Figure S7).
```{r}
mcmc_out <- readRDS("depends/mcmc-out.rds")
# signif(mcmc_out$max_rhat, 4)
```

Obtaining acceptable NUTS diagnostics required four chains run in parallel for 100,000 iterations, thinned by a factor of 20 for ease of storage.
The largest potential scale reduction factor [@gelman1992inference; @vehtari2021rank] was $\hat R = 1.021$ (Figure S7), and there were no divergent transitions.
We considered the NUTS results a gold-standard, though inaccuracies remain possible.
For full details see Appendix S2.
Though inaccuracies remain possible, we considered the NUTS results to be a gold-standard.
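As a rough sketch, these settings might be passed to `tmbstan` along the following lines; the object name `obj` (the compiled TMB objective) and the pass-through of `warmup`, `thin` and `cores` to `rstan::sampling()` are assumptions here, and the thinning factor of 40 follows the settings table above.

```r
# Sketch only: obj is assumed to be the TMB objective (output of TMB::MakeADFun)
mcmc <- tmbstan::tmbstan(
  obj,
  chains = 4,      # four parallel chains
  iter = 100000,   # total iterations per chain
  warmup = 50000,  # discarded as warmup
  thin = 40,       # keep every 40th draw
  cores = 4
)
```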

## Use of PCA-AGHQ \label{sec:pca-use}

@@ -441,7 +445,7 @@ knitr::include_graphics("depends/nodes-samples-comparison.png")
### Visual inspection

Overlaying the resulting $3^8 = 6561$ PCA-AGHQ nodes onto the hyperparameter marginal posteriors obtained using NUTS, we found approximately 12 of the 24 hyperparameters had well covered marginals (Figure \ref{fig:node-positions}).
Though 12 is an improvement on the 8 that would be naively achieved using a dense grid, there remained many hyperparameters poorly covered.
Though 12 improves on the 8 naively achievable using a dense grid, many hyperparameters remained poorly covered.
Coverage was associated with marginal standard deviation (Figure S4), which varied particularly according to the hyperparameter scale.
All constrained hyperparameters $\theta$ were transformed to the real line, using either a log ($\theta > 0$) or logit ($\theta \in [0, 1]$) transformation.
As a result, marginal standard deviations for log transformed hyperparameters were systematically smaller than those which were logit transformed (Figure S5).
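As a small illustration of this transformation step (the draws below are purely illustrative, not the Naomi hyperparameters), marginal standard deviations can be compared on the transformed scale as follows.

```r
# Sketch only: illustrative draws standing in for constrained hyperparameters
to_real_line <- function(theta, type = c("log", "logit")) {
  type <- match.arg(type)
  if (type == "log") log(theta) else stats::qlogis(theta)
}

sd(to_real_line(rgamma(1000, shape = 2, rate = 4), type = "log"))      # theta > 0
sd(to_real_line(rbeta(1000, shape1 = 2, shape2 = 2), type = "logit"))  # theta in [0, 1]
```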
@@ -475,7 +479,11 @@ Appendix S3 shows those values which we could compute in a reasonable time (less

### Posterior contraction

To assess the informativeness of the data we compared the prior variance $\sigma_\text{prior}^2(\psi)$ to the posterior variance $\sigma_\text{posterior}^2(\psi)$ via the posterior contraction $c(\psi) = 1 - (\sigma_\text{posterior}^2(\psi) / \sigma_\text{prior}^2(\psi))$ [@schad2021toward], where $\psi$ is a model parameter.
To assess the informativeness of the data we compared the prior variance $\sigma_\text{prior}^2(\psi)$ to the posterior variance $\sigma_\text{posterior}^2(\psi)$ via the posterior contraction [@schad2021toward]
\begin{equation}
c(\psi) = 1 - (\sigma_\text{posterior}^2(\psi) / \sigma_\text{prior}^2(\psi)),
\end{equation}
where $\psi$ is a model parameter.
We found that (Figure \ref{fig:contraction}) something something.
For greater interpretability, parameters in this plot could be faceted according to model component.
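A minimal sketch of computing this quantity from samples; `prior_samples` and `posterior_samples` are assumed here to be vectors of draws of a single parameter $\psi$.

```r
# Sketch only: prior_samples and posterior_samples are draws of one parameter psi
posterior_contraction <- function(prior_samples, posterior_samples) {
  1 - var(posterior_samples) / var(prior_samples)
}

# e.g. a N(0, 2.5^2) prior updated to a posterior with standard deviation 0.5
posterior_contraction(rnorm(1e4, 0, 2.5), rnorm(1e4, 1, 0.5))  # close to 0.96
```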

@@ -539,12 +547,12 @@ Results for the PSIS analysis are pending.

### Maximum mean discrepancy

Let $\Psi = \{\bpsi_i\}_{i = 1}^n$ and $\Psi = \{\bpsi_i\}_{i = 1}^n$ be two sets of joint posterior samples, and $k$ be a kernel.
Let $\Psi^{1} = \{\bpsi^1_i\}_{i = 1}^n$ and $\Psi^2 = \{\bpsi^2_i\}_{i = 1}^n$ be two sets of joint posterior samples, and $k$ be a kernel.
The maximum mean discrepancy [MMD; @gretton2006kernel] can be empirically estimated by
\begin{equation*}
\text{MMD}(\Psi, \Psi) = \sqrt{\frac{1}{n^2} \sum_{i, j = 1}^n k(\bpsi_i, \bpsi_j) - \frac{2}{n^2} \sum_{i, j = 1}^n k(\bpsi_i, \bpsi_j) + \frac{1}{n^2} \sum_{i, j = 1}^n k(\bpsi_i, \bpsi_j)}.
\text{MMD}(\Psi^1, \Psi^2) = \sqrt{\frac{1}{n^2} \sum_{i, j = 1}^n k(\bpsi^1_i, \bpsi^1_j) - \frac{2}{n^2} \sum_{i, j = 1}^n k(\bpsi_i^1, \bpsi_j^2) + \frac{1}{n^2} \sum_{i, j = 1}^n k(\bpsi^2_i, \bpsi^2_j)}.
\end{equation*}
We set $k(\bpsi_i, \bpsi_j) = \exp(-\sigma \lVert \bpsi_i - \bpsi_j \rVert^2)$ with $\sigma$ estimated from data using the `kernlab` \textsc{R} package [@karatzoglou2019package].
We set $k(\bpsi^1, \bpsi^2) = \exp(-\sigma \lVert \bpsi^1 - \bpsi^2 \rVert^2)$ with $\sigma$ estimated from data using the `kernlab` \textsc{R} package [@karatzoglou2019package].
As compared with NUTS, the MMD from PCA-AGHQ (0.071) was 11% smaller than that of TMB (0.080).

<!-- This should be calculated inline! -->
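A self-contained sketch of this estimator; the toy matrices `psi1` and `psi2` and the fixed `sigma` are illustrative assumptions (in the analysis $\sigma$ was estimated with `kernlab`).

```r
# Sketch only: psi1 and psi2 are n x d matrices of joint posterior draws
rbf_gram <- function(A, B, sigma) {
  d2 <- outer(rowSums(A^2), rowSums(B^2), "+") - 2 * A %*% t(B)  # squared distances
  exp(-sigma * d2)
}

mmd_hat <- function(psi1, psi2, sigma) {
  sqrt(mean(rbf_gram(psi1, psi1, sigma)) -
    2 * mean(rbf_gram(psi1, psi2, sigma)) +
    mean(rbf_gram(psi2, psi2, sigma)))
}

set.seed(1)
psi1 <- matrix(rnorm(200), ncol = 2)
psi2 <- matrix(rnorm(200, mean = 0.5), ncol = 2)
mmd_hat(psi1, psi2, sigma = 1)
```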
@@ -570,7 +578,7 @@ Naomi can be used to assess the probability of a stratum having high incidence by
We found that both TMB and PCA-AGHQ overestimate these exceedance probabilities (Figure \ref{fig:exceedance}, second row).
This is surprising, in that we expect inferences from NUTS to be more heavy-tailed than those from TMB or PCA-AGHQ.
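For reference, a hedged sketch of how such an exceedance probability can be computed from posterior samples; the object name and the incidence threshold below are illustrative only.

```r
# Sketch only: incidence_samples is a vector of posterior draws of incidence
# for one stratum, and the 1% threshold is illustrative
exceedance_probability <- function(incidence_samples, threshold = 0.01) {
  mean(incidence_samples > threshold)
}

exceedance_probability(rbeta(1000, 2, 150))  # toy draws, not Naomi output
```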

```{r exceedance, fig.cap="Though PCA-AGHQ does perform slightly better, both approximate inference methods are meaningfully inaccurate as compared with NUTS for estimating exceedance probabilities. For the second 90 target the inaccuracy varies substantially by sex."}
```{r exceedance, fig.cap="Though PCA-AGHQ was marginally better, both approximate inference methods were meaningfully inaccurate as compared with NUTS for estimating exceedance probabilities. For the second 90 target the inaccuracy varied substantially by sex."}
knitr::include_graphics("depends/exceedance.png")
```

64 changes: 53 additions & 11 deletions src/naomi-simple_mcmc/mcmc-convergence.Rmd
@@ -22,20 +22,32 @@ We start by obtaining results from the latest version of `naomi-simple_fit` with
```{r}
out <- readRDS("depends/out.rds")
mcmc <- out$mcmc$stanfit
depends <- yaml::read_yaml("orderly.yml")$depends
dependency_details <- function(i) {
report_name <- names(depends[[i]])
print(paste0("Inference results obtained from ", report_name, " with the query ", depends[[i]][[report_name]]$id))
report_id <- orderly::orderly_search(query = depends[[i]][[report_name]]$id, report_name)
print(paste0("Obtained report had ID ", report_id, " and was run with the following parameters:"))
print(orderly::orderly_info(report_id, report_name)$parameters)
}
dependency_details(1)
```

This MCMC took `r round(out$time, 3)` days to run.

```{r}
cbpalette <- c("#56B4E9", "#009E73", "#E69F00", "#F0E442", "#0072B2", "#D55E00", "#CC79A7", "#999999")
bayesplot::color_scheme_set("viridisA")
color_scheme_set("viridis")
ggplot2::theme_set(theme_minimal())
```

# $\hat R$

We are looking for values of $\hat R$ less than 1.1 here.
We are looking for values of $\hat R$ less than 1.05 here.

```{r}
rhats <- bayesplot::rhat(mcmc)
@@ -55,8 +67,10 @@ bayesplot::mcmc_rhat_data(rhats) %>%
ggsave("rhat.png", h = 3, w = 6.25)
(big_rhats <- rhats[rhats > 1.1])
(big_rhats <- rhats[rhats > 1.05])
length(big_rhats) / length(rhats)
(max_rhat <- max(rhats))
```

# ESS ratio
@@ -84,6 +98,8 @@ bayesplot::mcmc_neff_data(ratios) %>%
)
ggsave("ratio.png", h = 3, w = 6.25)
(average_ess_ratio <- mean(ratios))
```

# ESS
@@ -101,6 +117,28 @@ data.frame(mcmc_summary) %>%
labs(x = "ESS", y = "Count")
ggsave("ess.png", h = 3, w = 6.25)
(ess_min <- min(mcmc_summary[, "n_eff"]))
(ess_lower <- quantile(mcmc_summary[, "n_eff"], 0.025))
(ess_median <- quantile(mcmc_summary[, "n_eff"], 0.50))
(ess_upper <- quantile(mcmc_summary[, "n_eff"], 0.975))
(ess_max <- max(mcmc_summary[, "n_eff"]))
```

Save outputs for use in the manuscript:

```{r}
out <- list(
"max_rhat" = max_rhat,
"average_ess_ratio" = average_ess_ratio,
"ess_min" = ess_min,
"ess_lower" = ess_lower,
"ess_median" = ess_median,
"ess_upper" = ess_upper,
"ess_max" = ess_max
)
saveRDS(out, "out.rds")
```

# Autocorrelation
@@ -119,6 +157,14 @@ bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("logit")))
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("log_sigma")))
```

The parameters with the worst ESS and the worst $\hat R$:

```{r}
(plot <- bayesplot::mcmc_trace(mcmc, pars = c(names(which.min(mcmc_summary[, "n_eff"])), names(which.max(rhats)))))
ggsave("worst-trace.png", plot, h = 4, w = 6.25)
```

## Prevalence model

```{r}
@@ -167,9 +213,9 @@ Hence there is an unidentifiability here that leads to correlated posteriors.
bayesplot::mcmc_pairs(mcmc, pars = c("log_sigma_alpha_a", "logit_phi_alpha_a"), diag_fun = "hist", off_diag_fun = "hex")
bayesplot::mcmc_pairs(mcmc, pars = c("log_sigma_alpha_as", "logit_phi_alpha_as"), diag_fun = "hist", off_diag_fun = "hex")
bayesplot::mcmc_pairs(mcmc, pars = c("log_sigma_rho_as", "logit_phi_rho_as"), diag_fun = "hist", off_diag_fun = "hex")
bayesplot::mcmc_pairs(mcmc, pars = c("log_sigma_rho_a", "logit_phi_rho_a"), diag_fun = "hist", off_diag_fun = "hex")
(plot <- bayesplot::mcmc_pairs(mcmc, pars = c("log_sigma_rho_a", "logit_phi_rho_a"), diag_fun = "hist", off_diag_fun = "hex"))
ggsave("rho_a.png", h = 4, w = 6.25)
ggsave("rho_a.png", plot, h = 4, w = 6.25)
get_correlation <- function(par) {
cor(as.data.frame(rstan::extract(mcmc, c(paste0("log_sigma_", par), paste0("logit_phi_", par)))))[1, 2]
@@ -193,9 +239,9 @@ Looks like the answer is mostly yes.
bayesplot::mcmc_pairs(mcmc, pars = c("log_sigma_rho_x", "logit_phi_rho_x"), diag_fun = "hist", off_diag_fun = "hex")
bayesplot::mcmc_pairs(mcmc, pars = c("log_sigma_rho_xs", "logit_phi_rho_xs"), diag_fun = "hist", off_diag_fun = "hex")
bayesplot::mcmc_pairs(mcmc, pars = c("log_sigma_alpha_xs", "logit_phi_alpha_xs"), diag_fun = "hist", off_diag_fun = "hex")
bayesplot::mcmc_pairs(mcmc, pars = c("log_sigma_alpha_x", "logit_phi_alpha_x"), diag_fun = "hist", off_diag_fun = "hex")
(plot <- bayesplot::mcmc_pairs(mcmc, pars = c("log_sigma_alpha_x", "logit_phi_alpha_x"), diag_fun = "hist", off_diag_fun = "hex"))
ggsave("alpha_x.png", h = 4, w = 6.25)
ggsave("alpha_x.png", plot, h = 4, w = 6.25)
bym2_cor_df <- data.frame(
par = c("rho_x", "rho_xs", "alpha_x", "alpha_xs"),
@@ -224,10 +270,6 @@ neighbours_pairs_plot <- function(par, i) {
neighbour_pars <- paste0(par, "[", c(i, nb[[i]]), "]")
bayesplot::mcmc_pairs(mcmc, pars = neighbour_pars, diag_fun = "hist", off_diag_fun = "hex")
}
# area_merged %>%
# filter(area_level == max(area_level)) %>%
# print(n = Inf)
```

Here are Nkhata Bay and neighbours:
2 changes: 2 additions & 0 deletions src/naomi-simple_mcmc/orderly.yml
@@ -11,9 +11,11 @@ artefacts:
- rhat.png
- ratio.png
- ess.png
- out.rds
- rho_a.png
- alpha_x.png
- ar1-bym2-cor.csv
- worst-trace.png
- nuts-params.rds

resources:
