Skip to content

Commit

Permalink
Better KS summary plot (still some workshopping to go), worst paramet…
Browse files Browse the repository at this point in the history
…ers automatically plotted #55
  • Loading branch information
athowes committed Jul 21, 2023
1 parent e8a8f9b commit 7927613
Show file tree
Hide file tree
Showing 4 changed files with 262 additions and 251 deletions.
404 changes: 217 additions & 187 deletions docs/ks.html

Large diffs are not rendered by default.

37 changes: 0 additions & 37 deletions src/naomi-simple_ks/functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -199,40 +199,3 @@ ks_plot <- function(ks_df, par, method1 = "TMB", method2 = "aghq", alpha = 0.5)
scatterplot + jitterplot +
plot_layout(widths = c(2, 1))
}

#' Create a density plot and ridgeplot of the KS test statistics
#'
#' @param ks_df The output of `to_ks_df`
#' @param par Parameter name (only used for labelling)
#' @param method1 Samples from this method will be used as the first entry in the KS test
#' @param method2 Samples from this method will be used as the second entry in the KS test
#' @return A `ggplot2` object
ks_plot_many <- function(ks_summary, method1, method2) {
ks_method1 <- paste0("KS(", method1, ", tmbstan)")
ks_method2 <- paste0("KS(", method2, ", tmbstan)")

xy_length <- min(1, max(ks_summary[[ks_method1]], ks_summary[[ks_method2]]) + 0.03)

densityplot <- ggplot() +
stat_density_2d(data = ks_summary, aes(x = .data[[ks_method1]], y = .data[[ks_method2]], linetype = type), col = "black") +
xlim(0, xy_length) +
ylim(0, xy_length) +
geom_abline(intercept = 0, slope = 1, linetype = "dashed", alpha = 0.5) +
scale_linetype_manual(values = c("solid", "dashed")) +
labs(x = ks_method1, y = ks_method2) +
theme_minimal() +
guides(fill = "none") +
theme(legend.position = "bottom")

ks_summary[["KS difference"]] <- ks_summary[[ks_method1]] - ks_summary[[ks_method2]]

ridgeplot <- ggplot(ks_summary, aes(y = type, x = `KS difference`)) +
ggridges::geom_density_ridges(alpha = 0.7, fill = NA, aes(linetype = type)) +
coord_flip() +
scale_linetype_manual(values = c("solid", "dashed")) +
labs(y = "", x = paste0(ks_method1, " - ", ks_method2)) +
guides(linetype = "none") +
theme_minimal()

densityplot + ridgeplot
}
71 changes: 44 additions & 27 deletions src/naomi-simple_ks/ks.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ theta_names <- names(tmb$fit$obj$env$par[-r])
dict <- data.frame(
parname = c(unique(x_names), unique(theta_names)),
type = c(rep("Latent field", length(unique(x_names))), rep("Hyper", length(unique(theta_names))))
type = c(rep("Latent", length(unique(x_names))), rep("Hyper", length(unique(theta_names))))
)
ks_df <- lapply(unique(names(tmb$fit$obj$env$par)), to_ks_df) %>%
Expand Down Expand Up @@ -485,27 +485,42 @@ options(dplyr.summarise.inform = FALSE)
ks_summary <- ks_df %>%
group_by(method, parname, type) %>%
summarise(ks = signif(mean(ks), 4)) %>%
pivot_wider(names_from = "method", values_from = "ks") %>%
rename(
"Parameter" = "parname",
"KS(aghq, tmbstan)" = "aghq",
"KS(TMB, tmbstan)" = "TMB",
)
summarise(
ks = mean(ks),
size = n()
) %>%
pivot_wider(names_from = "method", values_from = "ks")
saveRDS(ks_summary, "ks-summary.rds")
ks_summary %>%
filter(type != "Hyper") %>%
ks_plot_many("TMB", "aghq")
extended_cbpalette <- colorRampPalette(multi.utils::cbpalette())
ks_summary_latent <- filter(ks_summary, type == "Latent field")
xy_length <- min(1, max(ks_summary_latent$aghq, ks_summary_latent$TMB) + 0.03)
ks_summary_latent %>%
ggplot(aes(x = TMB, y = aghq, col = parname, size = size)) +
geom_point(alpha = 0.4) +
xlim(0, xy_length) +
ylim(0, xy_length) +
scale_color_manual(values = extended_cbpalette(n = 20)) +
scale_size_continuous(breaks = c(2, 10, 32), labels = c(2, 10, 32)) +
geom_abline(slope = 1, intercept = 0, linetype = "dashed") +
labs(x = "KS(TMB, NUTS)", y = "KS(PCA-AGHQ, NUTS)", col = "Parameter", size = "Length", subtitle = "Smaller KS values indicate higher accuracy") +
theme_minimal() +
guides(size = guide_legend(title.position = "top", direction = "vertical"), col = guide_legend(title.position = "top")) +
theme(legend.position = "bottom", legend.title = element_text(size = rel(0.9)), legend.text = element_text(size = rel(0.7)))
ggsave("ks-summary.png", h = 5, w = 6.25, bg = "white")
```

```{r}
ks_summary %>%
group_by(type) %>%
summarise(
TMB = mean(`KS(TMB, tmbstan)`),
aghq = mean(`KS(aghq, tmbstan)`)
TMB = mean(TMB),
aghq = mean(aghq)
) %>%
gt::gt()
```
Expand All @@ -514,8 +529,8 @@ ks_summary %>%

```{r}
ks_summary %>%
filter(type == "Latent field") %>%
mutate(diff = signif(`KS(TMB, tmbstan)` - `KS(aghq, tmbstan)`, 3)) %>%
filter(type == "Latent") %>%
mutate(diff = signif(TMB - aghq, 3)) %>%
DT::datatable()
```

Expand All @@ -532,25 +547,27 @@ ks_df_wide <- ks_df %>%
### Nodes where `TMB` beats `aghq`

```{r}
ks_df_wide %>%
arrange(desc(diff)) %>%
head(n = 10)
(tmb_beats_aghq <- ks_df_wide %>%
filter(!is.na(type)) %>%
arrange(diff) %>%
head(n = 10))
histogram_and_ecdf(par = "us_alpha_xs", i = 18)
histogram_and_ecdf(par = "us_alpha_xs", i = 16)
histogram_and_ecdf(par = "u_alpha_xs", i = 13)
histogram_and_ecdf(par = tmb_beats_aghq$parname[1], i = tmb_beats_aghq$index[1])
histogram_and_ecdf(par = tmb_beats_aghq$parname[2], i = tmb_beats_aghq$index[2])
histogram_and_ecdf(par = tmb_beats_aghq$parname[3], i = tmb_beats_aghq$index[3])
```

### Nodes where `aghq` beats `TMB`

```{r}
ks_df_wide %>%
arrange(diff) %>%
head(n = 10)
(aghq_beats_tmb <- ks_df_wide %>%
filter(!is.na(type)) %>%
arrange(desc(diff)) %>%
head(n = 10))
histogram_and_ecdf(par = "ui_anc_rho_x", i = 27)
histogram_and_ecdf(par = "ui_anc_alpha_x", i = 19)
histogram_and_ecdf(par = "u_alpha_x", i = 18)
histogram_and_ecdf(par = aghq_beats_tmb$parname[1], i = aghq_beats_tmb$index[1])
histogram_and_ecdf(par = aghq_beats_tmb$parname[2], i = aghq_beats_tmb$index[2])
histogram_and_ecdf(par = aghq_beats_tmb$parname[3], i = aghq_beats_tmb$index[3])
```

## Correlation between KS values and ESS
Expand Down
1 change: 1 addition & 0 deletions src/naomi-simple_ks/orderly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ artefacts:
- data:
description: Saved plots
filenames:
- ks-summary.png
- ks-ess.png

resources:
Expand Down

0 comments on commit 7927613

Please sign in to comment.