11# Leave-one-out cross-validation
2- #
2+ #
33# The \code{loo} method for stanfit objects ---a wrapper around the
44# \code{loo.array} method from the \pkg{loo} package--- computes approximate
55# leave-one-out cross-validation using Pareto smoothed importance sampling
66# (PSIS-LOO CV).
7- #
7+ #
88# @param x stanfit object
99# @param pars Name of parameter, transformed parameter, or generated quantity in
1010# the Stan program corresponding to the pointwise log-likelihood. If not
2020# @param k_threshold Threshold value for Pareto k values above which
2121# the moment matching algorithm is used. If \code{moment_match} is \code{FALSE},
2222# this is ignored.
23+ # @param r_eff Whether to compute r_eff to pass to loo package.
2324# @param ... Ignored.
24- #
25+ #
2526# @details Stan does not automatically compute and store the log-likelihood. It
2627# is up to the user to incorporate it into the Stan program if it is to be
2728# extracted after fitting the model. In a Stan model, the pointwise log
@@ -47,31 +48,37 @@ loo.stanfit <-
4748 cores = getOption(" mc.cores" , 1 ),
4849 moment_match = FALSE ,
4950 k_threshold = 0.7 ,
51+ r_eff = FALSE ,
5052 ... ) {
5153 stopifnot(length(pars ) == 1L )
5254 stopifnot(is.logical(save_psis ))
5355 stopifnot(is.logical(moment_match ))
5456 stopifnot(is.numeric(k_threshold ))
55-
57+ stopifnot(is.logical(r_eff ))
58+
5659 LLarray <- loo :: extract_log_lik(stanfit = x ,
5760 parameter_name = pars ,
5861 merge_chains = FALSE )
59- r_eff <- loo :: relative_eff(x = exp(LLarray ), cores = cores )
60-
62+
63+ if (! r_eff ) {
64+ r_eff <- NULL
65+ } else {
66+ r_eff <- loo :: relative_eff(x = exp(LLarray ), cores = cores )
67+ }
68+
6169 if (moment_match ) {
6270 loo <- suppressWarnings(loo :: loo.array(LLarray ,
6371 r_eff = r_eff ,
6472 cores = cores ,
6573 save_psis = save_psis ))
66-
74+
6775 x_array <- as.array(x )
6876 chain_id <- rep(seq(dim(x_array )[2 ]),each = dim(x_array )[1 ])
6977 loo <- loo_moment_match.stanfit(
7078 x , loo = loo , chain_id = chain_id , k_threshold = k_threshold ,
7179 cores = cores , parameter_name = pars , ...
7280 )
73- }
74- else {
81+ } else {
7582 loo <- loo :: loo.array(LLarray ,
7683 r_eff = r_eff ,
7784 cores = cores ,
0 commit comments