Skip to content

Commit a6a4f55

Browse files
authored
Merge pull request #1170 from stan-dev/loo-r_eff-argument
Add r_eff argument to loo method
2 parents 32373c3 + b42fde1 commit a6a4f55

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

rstan/rstan/R/loo.R

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# Leave-one-out cross-validation
2-
#
2+
#
33
# The \code{loo} method for stanfit objects ---a wrapper around the
44
# \code{loo.array} method from the \pkg{loo} package--- computes approximate
55
# leave-one-out cross-validation using Pareto smoothed importance sampling
66
# (PSIS-LOO CV).
7-
#
7+
#
88
# @param x stanfit object
99
# @param pars Name of parameter, transformed parameter, or generated quantity in
1010
# the Stan program corresponding to the pointwise log-likelihood. If not
@@ -20,8 +20,9 @@
2020
# @param k_threshold Threshold value for Pareto k values above which
2121
# the moment matching algorithm is used. If \code{moment_match} is \code{FALSE},
2222
# this is ignored.
23+
# @param r_eff Whether to compute r_eff to pass to loo package.
2324
# @param ... Ignored.
24-
#
25+
#
2526
# @details Stan does not automatically compute and store the log-likelihood. It
2627
# is up to the user to incorporate it into the Stan program if it is to be
2728
# extracted after fitting the model. In a Stan model, the pointwise log
@@ -47,31 +48,37 @@ loo.stanfit <-
4748
cores = getOption("mc.cores", 1),
4849
moment_match = FALSE,
4950
k_threshold = 0.7,
51+
r_eff = FALSE,
5052
...) {
5153
stopifnot(length(pars) == 1L)
5254
stopifnot(is.logical(save_psis))
5355
stopifnot(is.logical(moment_match))
5456
stopifnot(is.numeric(k_threshold))
55-
57+
stopifnot(is.logical(r_eff))
58+
5659
LLarray <- loo::extract_log_lik(stanfit = x,
5760
parameter_name = pars,
5861
merge_chains = FALSE)
59-
r_eff <- loo::relative_eff(x = exp(LLarray), cores = cores)
60-
62+
63+
if (!r_eff) {
64+
r_eff <- NULL
65+
} else {
66+
r_eff <- loo::relative_eff(x = exp(LLarray), cores = cores)
67+
}
68+
6169
if (moment_match) {
6270
loo <- suppressWarnings(loo::loo.array(LLarray,
6371
r_eff = r_eff,
6472
cores = cores,
6573
save_psis = save_psis))
66-
74+
6775
x_array <- as.array(x)
6876
chain_id <- rep(seq(dim(x_array)[2]),each = dim(x_array)[1])
6977
loo <- loo_moment_match.stanfit(
7078
x, loo = loo, chain_id = chain_id, k_threshold = k_threshold,
7179
cores = cores, parameter_name = pars, ...
7280
)
73-
}
74-
else {
81+
} else {
7582
loo <- loo::loo.array(LLarray,
7683
r_eff = r_eff,
7784
cores = cores,

rstan/rstan/man/stanfit-method-loo.Rd

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ sampling (Vehtari, Gelman, and Gabry, 2017a,2017b).
2020
cores = getOption("mc.cores", 1),
2121
moment_match = FALSE,
2222
k_threshold = 0.7,
23+
r_eff = FALSE,
2324
\dots)
2425
}
2526

@@ -43,6 +44,14 @@ sampling (Vehtari, Gelman, and Gabry, 2017a,2017b).
4344
\item{k_threshold}{Threshold value for Pareto k values above which
4445
the moment matching algorithm is used. If \code{moment_match} is \code{FALSE},
4546
this is ignored.}
47+
\item{r_eff}{\code{TRUE} or \code{FALSE} indicating whether to compute the
48+
\code{r_eff} argument to pass to the \pkg{loo} package. If \code{TRUE},
49+
will call \code{loo::relative_eff()}. If \code{FALSE}
50+
(the default), we avoid computing \code{r_eff}, which can be very slow.
51+
\code{r_eff} measures the amount of autocorrelation in MCMC draws, and is
52+
used to compute more accurate ESS and MCSE estimates for pointwise and
53+
total ELPDs. When \code{r_eff=FALSE}, the reported ESS and MCSE estimates
54+
may be over-optimistic if the posterior draws are far from independent.}
4655
\item{\dots}{Ignored.}
4756
}
4857

0 commit comments

Comments
 (0)