Skip to content

Commit 569b06c

Browse files
committed
1 parent 36f8c73 commit 569b06c

1 file changed

Lines changed: 31 additions & 11 deletions

File tree

R/loo.R

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,13 @@
5959
#' @param k_threshold Threshold for flagging estimates of the Pareto shape
6060
#' parameters \eqn{k} estimated by \code{loo}. See the \emph{How to proceed
6161
#' when \code{loo} gives warnings} section, below, for details.
62-
#'
62+
#' @param r_eff \code{TRUE} or \code{FALSE} indicating whether to compute the
63+
#' \code{r_eff} argument to pass to the \pkg{loo} package. If \code{TRUE},
64+
#' \pkg{rstanarm} will call \code{\link[loo]{relative_eff}} to compute the
65+
#' \code{r_eff} argument to pass to the \pkg{loo} package. If \code{FALSE}
66+
#' (the default), we avoid computing \code{r_eff}, which can be very slow. The
67+
#' reported ESS and MCSE estimates may be over-optimistic if the posterior
68+
#' draws are far from independent.
6369
#' @return The structure of the objects returned by \code{loo} and \code{waic}
6470
#' methods are documented in detail in the \strong{Value} section in
6571
#' \code{\link[loo]{loo}} and \code{\link[loo]{waic}} (from the \pkg{loo}
@@ -184,9 +190,15 @@ loo.stanreg <-
184190
...,
185191
cores = getOption("mc.cores", 1),
186192
save_psis = FALSE,
187-
k_threshold = NULL) {
188-
if (model_has_weights(x))
193+
k_threshold = NULL,
194+
r_eff = FALSE) {
195+
if (model_has_weights(x)) {
189196
recommend_exact_loo(reason = "model has weights")
197+
}
198+
199+
if (!r_eff) {
200+
r_eff <- NULL
201+
}
190202

191203
user_threshold <- !is.null(k_threshold)
192204
if (user_threshold) {
@@ -196,9 +208,9 @@ loo.stanreg <-
196208
}
197209

198210

199-
if (used.sampling(x)) # chain_id to pass to loo::relative_eff
211+
if (used.sampling(x)) {# chain_id to pass to loo::relative_eff
200212
chain_id <- chain_id_for_loo(x)
201-
else { # ir_idx to pass to ...
213+
} else { # ir_idx to pass to ...
202214
if (exists("ir_idx",x)) {
203215
ir_idx <- x$ir_idx
204216
} else if ("diagnostics" %in% names(x$stanfit@sim) &
@@ -212,7 +224,9 @@ loo.stanreg <-
212224

213225
if (is.stanjm(x)) {
214226
ll <- log_lik(x)
215-
r_eff <- loo::relative_eff(exp(ll), chain_id = chain_id, cores = cores)
227+
if (!is.null(r_eff)) {
228+
r_eff <- loo::relative_eff(exp(ll), chain_id = chain_id, cores = cores)
229+
}
216230
loo_x <-
217231
suppressWarnings(loo.matrix(
218232
ll,
@@ -223,7 +237,9 @@ loo.stanreg <-
223237
} else if (is.stanmvreg(x)) {
224238
M <- get_M(x)
225239
ll <- do.call("cbind", lapply(1:M, function(m) log_lik(x, m = m)))
226-
r_eff <- loo::relative_eff(exp(ll), chain_id = chain_id, cores = cores)
240+
if (!is.null(r_eff)) {
241+
r_eff <- loo::relative_eff(exp(ll), chain_id = chain_id, cores = cores)
242+
}
227243
loo_x <-
228244
suppressWarnings(loo.matrix(
229245
ll,
@@ -242,7 +258,9 @@ loo.stanreg <-
242258
)
243259
ll <- ll[,!cons, drop = FALSE]
244260
}
245-
r_eff <- loo::relative_eff(exp(ll), chain_id = chain_id, cores = cores)
261+
if (!is.null(r_eff)) {
262+
r_eff <- loo::relative_eff(exp(ll), chain_id = chain_id, cores = cores)
263+
}
246264
loo_x <-
247265
suppressWarnings(loo.matrix(
248266
ll,
@@ -256,7 +274,7 @@ loo.stanreg <-
256274
likfun <- function(data_i, draws) {
257275
exp(llfun(data_i, draws))
258276
}
259-
if (used.sampling(x)) {
277+
if (used.sampling(x) && !is.null(r_eff)) {
260278
r_eff <- loo::relative_eff(
261279
# using function method
262280
x = likfun,
@@ -266,12 +284,14 @@ loo.stanreg <-
266284
cores = cores,
267285
...
268286
)
269-
} else {
287+
} else if (!used.sampling(x)) {
270288
w_ir <- as.numeric(table(ir_idx))/length(ir_idx)
271289
ir_uidx <- which(!duplicated(ir_idx))
272290
draws <- args$draws
273291
data <- args$data
274-
r_eff <- pmin(sapply(1:dim(data)[1], function(i) {lik_i <- likfun(data[i,], draws)[ir_uidx]; var(lik_i)/(sum(w_ir^2*(lik_i-mean(lik_i))^2))}),length(ir_uidx))/length(ir_idx)
292+
if (!is.null(r_eff)) {
293+
r_eff <- pmin(sapply(1:dim(data)[1], function(i) {lik_i <- likfun(data[i,], draws)[ir_uidx]; var(lik_i)/(sum(w_ir^2*(lik_i-mean(lik_i))^2))}),length(ir_uidx))/length(ir_idx)
294+
}
275295
}
276296
loo_x <- suppressWarnings(
277297
loo.function(

0 commit comments

Comments
 (0)