diff --git a/rstan/rstan/R/loo.R b/rstan/rstan/R/loo.R index 2be5668e2..5e4df9c36 100644 --- a/rstan/rstan/R/loo.R +++ b/rstan/rstan/R/loo.R @@ -1,10 +1,10 @@ # Leave-one-out cross-validation -# +# # The \code{loo} method for stanfit objects ---a wrapper around the # \code{loo.array} method from the \pkg{loo} package--- computes approximate # leave-one-out cross-validation using Pareto smoothed importance sampling # (PSIS-LOO CV). -# +# # @param x stanfit object # @param pars Name of parameter, transformed parameter, or generated quantity in # the Stan program corresponding to the pointwise log-likelihood. If not @@ -20,8 +20,9 @@ # @param k_threshold Threshold value for Pareto k values above which # the moment matching algorithm is used. If \code{moment_match} is \code{FALSE}, # this is ignored. +# @param r_eff Whether to compute r_eff to pass to loo package. # @param ... Ignored. -# +# # @details Stan does not automatically compute and store the log-likelihood. It # is up to the user to incorporate it into the Stan program if it is to be # extracted after fitting the model. In a Stan model, the pointwise log @@ -47,31 +48,37 @@ loo.stanfit <- cores = getOption("mc.cores", 1), moment_match = FALSE, k_threshold = 0.7, + r_eff = FALSE, ...) 
{ stopifnot(length(pars) == 1L) stopifnot(is.logical(save_psis)) stopifnot(is.logical(moment_match)) stopifnot(is.numeric(k_threshold)) - + stopifnot(is.logical(r_eff)) + LLarray <- loo::extract_log_lik(stanfit = x, parameter_name = pars, merge_chains = FALSE) - r_eff <- loo::relative_eff(x = exp(LLarray), cores = cores) - + + if (!r_eff) { + r_eff <- NULL + } else { + r_eff <- loo::relative_eff(x = exp(LLarray), cores = cores) + } + if (moment_match) { loo <- suppressWarnings(loo::loo.array(LLarray, r_eff = r_eff, cores = cores, save_psis = save_psis)) - + x_array <- as.array(x) chain_id <- rep(seq(dim(x_array)[2]),each = dim(x_array)[1]) loo <- loo_moment_match.stanfit( x, loo = loo, chain_id = chain_id, k_threshold = k_threshold, cores = cores, parameter_name = pars, ... ) - } - else { + } else { loo <- loo::loo.array(LLarray, r_eff = r_eff, cores = cores, diff --git a/rstan/rstan/man/stanfit-method-loo.Rd b/rstan/rstan/man/stanfit-method-loo.Rd index e11837cad..973baae8e 100644 --- a/rstan/rstan/man/stanfit-method-loo.Rd +++ b/rstan/rstan/man/stanfit-method-loo.Rd @@ -20,6 +20,7 @@ sampling (Vehtari, Gelman, and Gabry, 2017a,2017b). cores = getOption("mc.cores", 1), moment_match = FALSE, k_threshold = 0.7, + r_eff = FALSE, \dots) } @@ -43,6 +44,14 @@ sampling (Vehtari, Gelman, and Gabry, 2017a,2017b). \item{k_threshold}{Threshold value for Pareto k values above which the moment matching algorithm is used. If \code{moment_match} is \code{FALSE}, this is ignored.} + \item{r_eff}{\code{TRUE} or \code{FALSE} indicating whether to compute the + \code{r_eff} argument to pass to the \pkg{loo} package. If \code{TRUE}, + will call \code{loo::relative_eff()}. If \code{FALSE} + (the default), we avoid computing \code{r_eff}, which can be very slow. + \code{r_eff} measures the amount of autocorrelation in MCMC draws, and is + used to compute more accurate ESS and MCSE estimates for pointwise and + total ELPDs. 
When \code{r_eff = FALSE}, the reported ESS and MCSE estimates + may be over-optimistic if the posterior draws are far from independent.} \item{\dots}{Ignored.} }