First try at producing PSIS k hat statistic #41

athowes committed May 11, 2023 (commit 0ba4ce5), changing `src/naomi-simple_psis/psis.Rmd`.
Import these inference results as follows:
```{r}
tmb <- readRDS("depends/tmb.rds")
aghq <- readRDS("depends/aghq.rds")
tmbstan <- readRDS("depends/tmbstan.rds")
```

Check that the parameters (latent field, hyperparameters, model outputs) sampled from each of the three methods are the same:
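The chunk performing this check is collapsed in the diff view. A minimal sketch of such a check, assuming each method stores its draws as a named list (the accessors `tmb$fit$sample`, `aghq$quad$sample` and `tmbstan$mcmc$sample` are the ones used later in this report):

```{r}
# Sketch: the names of the sampled quantities should agree across methods
stopifnot(identical(names(tmb$fit$sample), names(aghq$quad$sample)))
stopifnot(identical(names(tmb$fit$sample), names(tmbstan$mcmc$sample)))
```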
For each sample we are going to want to evaluate the log-likelihood, so that we can compute the importance ratios required for PSIS.
Although we can extract a `TMB` objective function from `tmb` directly, it would correspond to the Laplace approximation.
Instead we use `objfull` obtained from a previous report, and load the `naomi_simple` DLL:

```{r}
TMB::compile("naomi_simple.cpp")
dyn.load(TMB::dynlib("naomi_simple"))
```

*[Lines collapsed in the diff view; judging from the text below, this section evaluates the log-likelihood of the `tmbstan` samples and plots it via a `data.frame` and `ggplot` pipeline ending in `theme_minimal()`.]*

We would like to produce evaluations of the log-likelihood for the `TMB` and `aghq` samples as well.
To do this, we first create a function which reformats the samples to be evaluated using `objfull$fn`:

```{r}
par_samp_matrix <- function(sample, elements = unique(names(objfull$par))) {
  # Keep and reorder to be the same as elements
  x <- sample[elements]
  # Bind rows together
  do.call(rbind, x)
}
tmb_samples <- par_samp_matrix(tmb$fit$sample)
aghq_samples <- par_samp_matrix(aghq$quad$sample)
tmbstan_samples <- par_samp_matrix(tmbstan$mcmc$sample)
```
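As a quick sanity check, each of these should be a matrix with one row per parameter element and one column per draw; a minimal check of this assumed layout:

```{r}
# Rows index parameter elements, columns index draws, so row counts must agree
dim(tmb_samples)
dim(aghq_samples)
dim(tmbstan_samples)
stopifnot(nrow(tmb_samples) == nrow(aghq_samples))
stopifnot(nrow(tmb_samples) == nrow(tmbstan_samples))
```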

Now we can calculate the log-likelihood by applying `objfull$fn` across the columns (one per draw) of the sample matrices:

```{r}
tmb_ll <- apply(tmb_samples, 2, FUN = function(x) -1 * objfull$fn(x))
aghq_ll <- apply(aghq_samples, 2, FUN = function(x) -1 * objfull$fn(x))
tmbstan_ll <- apply(tmbstan_samples, 2, FUN = function(x) -1 * objfull$fn(x))
```

*[Collapsed in the diff view: a `data.frame` and `ggplot` pipeline plotting the three log-likelihood traces, ending with:]*

```{r}
  scale_colour_manual(values = c("#56B4E9", "#009E73", "#E69F00")) +
  labs(y = "Log-likelihood", x = "Draw number", col = "Method") +
  theme_minimal()
```

To check that this is working as expected, we can compare the answers produced this way to those obtained directly from `lp__` in Stan:

```{r}
data.frame(lp__ = sort(tmbstan_extract$lp__), objfull = sort(tmbstan_ll)) %>%
ggplot(aes(x = lp__, y = objfull)) +
geom_point(alpha = 0.4) +
labs(title = "Appears these points lie on a straight line") +
theme_minimal()
```

It's interesting to note the differing levels of the log-likelihood between the three sampling methods.
Is there a way that this could be interpreted?
Perhaps Stan samples from the typical set, whereas `aghq` and `TMB` sample from closer to the mode.

```{r}
-1 * mean(tmb_ll)
-1 * mean(aghq_ll)
-1 * mean(tmbstan_ll)
```

To run PSIS we need to get the log-density of each sample under the proposal, as well as under the target.
The proposal for `TMB` is a multivariate Gaussian, but only for the random effects:

```{r}
r <- tmb$fit$obj$env$random # Indices of the random effects
par_r <- tmb$fit$par.full[r] # Mode of the random effects
# Hessian of the random effects (this has crashed on me before, so being safe)
safe <- parallel::mcparallel({tmb$fit$obj$env$spHess(tmb$fit$par.full, random = TRUE)})
hess_r <- parallel::mccollect(safe, wait = TRUE, timeout = 0, intermediate = FALSE)
```
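The Gaussian proposal density is not evaluated in this first try. A minimal sketch of how it could be, mirroring the `sparseMVN` call used for `aghq` below (the names `tmb_chol`, `tmb_latent_samples` and `tmb_q_ll` are hypothetical, and we assume `parallel::mccollect` returns the Hessian as the single element of a list):

```{r}
# Sketch: evaluate each TMB sample under the Gaussian proposal N(par_r, hess_r^{-1})
tmb_chol <- Matrix::Cholesky(hess_r[[1]])
tmb_latent_samples <- par_samp_matrix(tmb$fit$sample, elements = unique(names(par_r)))
tmb_q_ll <- sparseMVN::dmvn.sparse(
  x = t(tmb_latent_samples), mu = par_r, CH = tmb_chol, prec = TRUE, log = TRUE
)
```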

For `aghq` it's a mixture of Gaussians:

```{r}
aghq_modes <- aghq$quad$modesandhessians$mode
aghq_hessians <- aghq$quad$modesandhessians$H
aghq_lambda <- exp(aghq$quad$normalized_posterior$nodesandweights$logpost_normalized) * aghq$quad$normalized_posterior$nodesandweights$weights
```

`aghq_lambda` contains the mixture weights, which we verify are in fact valid weights here:

```{r}
stopifnot(abs(sum(aghq_lambda) - 1) < 1e-8)
stopifnot(all(aghq_lambda > 0))
```

Now to construct the log-density of the samples under this mixture proposal, we will use the log-sum-exp trick:
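Concretely, writing $\lambda_k$ for the weights, $\mu_k$ for the modes and $H_k$ for the Hessians (so each component has covariance $H_k^{-1}$), the quantity computed below is
$$
\log q(\theta) = \log \sum_k \lambda_k \, \mathcal{N}(\theta \mid \mu_k, H_k^{-1}) = \mathrm{logsumexp}_k \left\{ \log \lambda_k + \log \mathcal{N}(\theta \mid \mu_k, H_k^{-1}) \right\},
$$
with the log-sum-exp evaluated stably by `matrixStats::logSumExp`.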

```{r}
aghq_chol <- lapply(aghq_hessians, Matrix::Cholesky)
aghq_latent_samples <- par_samp_matrix(aghq$quad$sample, elements = unique(names(aghq_modes[[1]])))
start <- Sys.time()
mvn_ll <- mapply(sparseMVN::dmvn.sparse, mu = aghq_modes, CH = aghq_chol, MoreArgs = list(x = t(aghq_latent_samples), prec = TRUE, log = TRUE))
end <- Sys.time()
# Takes around 2 minutes
end - start
aghq_q_ll <- apply(mvn_ll, 1, FUN = function(row) matrixStats::logSumExp(log(aghq_lambda) + row))
plot(aghq_q_ll)
```

Now we run PSIS for `aghq`.
The log importance ratios are $\log p(\theta, y) - \log q(\theta)$:

```{r}
r <- aghq_ll - aghq_q_ll # log p(theta, y) - log q(theta)
plot(r)
aghq_psis <- loo::psis(log_ratios = r, r_eff = 1)
aghq_psis$diagnostics
```
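For reference, following the PSIS paper and the `loo` documentation: $\hat{k} \leq 0.5$ suggests the importance weights are well behaved, $0.5 < \hat{k} \leq 0.7$ is usable but less reliable, and $\hat{k} > 0.7$ indicates the proposal is too far from the target for PSIS to be trusted. Since we pass a single vector of log ratios, the diagnostics contain a single $\hat{k}$:

```{r}
# Extract the estimated Pareto shape parameter k-hat
aghq_psis$diagnostics$pareto_k
```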
