5959# ' @param k_threshold Threshold for flagging estimates of the Pareto shape
6060# ' parameters \eqn{k} estimated by \code{loo}. See the \emph{How to proceed
6161# ' when \code{loo} gives warnings} section, below, for details.
62- # '
62+ # ' @param r_eff \code{TRUE} or \code{FALSE} indicating whether to compute the
63+ # ' \code{r_eff} argument to pass to the \pkg{loo} package. If \code{TRUE},
64+ # ' \pkg{rstanarm} will call \code{\link[loo]{relative_eff}} to compute the
65+ # ' \code{r_eff} argument to pass to the \pkg{loo} package. If \code{FALSE}
66+ # ' (the default), we avoid computing \code{r_eff}, which can be very slow. The
67+ # ' reported ESS and MCSE estimates may be over-optimistic if the posterior
68+ # ' draws are far from independent.
6369# ' @return The structure of the objects returned by \code{loo} and \code{waic}
6470# ' methods are documented in detail in the \strong{Value} section in
6571# ' \code{\link[loo]{loo}} and \code{\link[loo]{waic}} (from the \pkg{loo}
@@ -184,9 +190,15 @@ loo.stanreg <-
184190 ... ,
185191 cores = getOption(" mc.cores" , 1 ),
186192 save_psis = FALSE ,
187- k_threshold = NULL ) {
188- if (model_has_weights(x ))
193+ k_threshold = NULL ,
194+ r_eff = FALSE ) {
195+ if (model_has_weights(x )) {
189196 recommend_exact_loo(reason = " model has weights" )
197+ }
198+
199+ if (! r_eff ) {
200+ r_eff <- NULL
201+ }
190202
191203 user_threshold <- ! is.null(k_threshold )
192204 if (user_threshold ) {
@@ -196,9 +208,9 @@ loo.stanreg <-
196208 }
197209
198210
199- if (used.sampling(x )) # chain_id to pass to loo::relative_eff
211+ if (used.sampling(x )) { # chain_id to pass to loo::relative_eff
200212 chain_id <- chain_id_for_loo(x )
201- else { # ir_idx to pass to ...
213+ } else { # ir_idx to pass to ...
202214 if (exists(" ir_idx" ,x )) {
203215 ir_idx <- x $ ir_idx
204216 } else if (" diagnostics" %in% names(x $ stanfit @ sim ) &
@@ -212,7 +224,9 @@ loo.stanreg <-
212224
213225 if (is.stanjm(x )) {
214226 ll <- log_lik(x )
215- r_eff <- loo :: relative_eff(exp(ll ), chain_id = chain_id , cores = cores )
227+ if (! is.null(r_eff )) {
228+ r_eff <- loo :: relative_eff(exp(ll ), chain_id = chain_id , cores = cores )
229+ }
216230 loo_x <-
217231 suppressWarnings(loo.matrix(
218232 ll ,
@@ -223,7 +237,9 @@ loo.stanreg <-
223237 } else if (is.stanmvreg(x )) {
224238 M <- get_M(x )
225239 ll <- do.call(" cbind" , lapply(1 : M , function (m ) log_lik(x , m = m )))
226- r_eff <- loo :: relative_eff(exp(ll ), chain_id = chain_id , cores = cores )
240+ if (! is.null(r_eff )) {
241+ r_eff <- loo :: relative_eff(exp(ll ), chain_id = chain_id , cores = cores )
242+ }
227243 loo_x <-
228244 suppressWarnings(loo.matrix(
229245 ll ,
@@ -242,7 +258,9 @@ loo.stanreg <-
242258 )
243259 ll <- ll [,! cons , drop = FALSE ]
244260 }
245- r_eff <- loo :: relative_eff(exp(ll ), chain_id = chain_id , cores = cores )
261+ if (! is.null(r_eff )) {
262+ r_eff <- loo :: relative_eff(exp(ll ), chain_id = chain_id , cores = cores )
263+ }
246264 loo_x <-
247265 suppressWarnings(loo.matrix(
248266 ll ,
@@ -256,7 +274,7 @@ loo.stanreg <-
256274 likfun <- function (data_i , draws ) {
257275 exp(llfun(data_i , draws ))
258276 }
259- if (used.sampling(x )) {
277+ if (used.sampling(x ) && ! is.null( r_eff ) ) {
260278 r_eff <- loo :: relative_eff(
261279 # using function method
262280 x = likfun ,
@@ -266,12 +284,14 @@ loo.stanreg <-
266284 cores = cores ,
267285 ...
268286 )
269- } else {
287+ } else if ( ! used.sampling( x )) {
270288 w_ir <- as.numeric(table(ir_idx ))/ length(ir_idx )
271289 ir_uidx <- which(! duplicated(ir_idx ))
272290 draws <- args $ draws
273291 data <- args $ data
274- r_eff <- pmin(sapply(1 : dim(data )[1 ], function (i ) {lik_i <- likfun(data [i ,], draws )[ir_uidx ]; var(lik_i )/ (sum(w_ir ^ 2 * (lik_i - mean(lik_i ))^ 2 ))}),length(ir_uidx ))/ length(ir_idx )
292+ if (! is.null(r_eff )) {
293+ r_eff <- pmin(sapply(1 : dim(data )[1 ], function (i ) {lik_i <- likfun(data [i ,], draws )[ir_uidx ]; var(lik_i )/ (sum(w_ir ^ 2 * (lik_i - mean(lik_i ))^ 2 ))}),length(ir_uidx ))/ length(ir_idx )
294+ }
275295 }
276296 loo_x <- suppressWarnings(
277297 loo.function(
0 commit comments