@@ -1063,6 +1063,24 @@ process_init.default <- function(init, ...) {
10631063 return (init )
10641064}
10651065
1066+ # ' Remove the leftmost dimension if equal to 1
1067+ # ' @noRd
1068+ # ' @param x An array like object
1069+ .remove_leftmost_dim <- function (x ) {
1070+ dims <- dim(x )
1071+ if (length(dims ) == 1 ) {
1072+ return (drop(x ))
1073+ } else if (dims [1 ] == 1 ) {
1074+ new_dims <- dims [- 1 ]
1075+ # Create a call to subset the array, maintaining all remaining dimensions
1076+ subset_expr <- as.call(c(as.name(" [" ), list (x ), 1 , rep(TRUE , length(new_dims )), drop = FALSE ))
1077+ new_x <- eval(subset_expr )
1078+ return (array (new_x , dim = new_dims ))
1079+ } else {
1080+ return (x )
1081+ }
1082+ }
1083+
10661084# ' Write initial values to files if provided as posterior `draws` object
10671085# ' @noRd
10681086# ' @param init A type that inherits the `posterior::draws` class.
@@ -1097,9 +1115,13 @@ process_init.draws <- function(init, num_procs, model_variables = NULL,
10971115 draws_rvar <- posterior :: subset_draws(draws_rvar , variable = variable_names )
10981116 inits = lapply(1 : num_procs , function (draw_iter ) {
10991117 init_i = lapply(variable_names , function (var_name ) {
1100- x = drop(posterior :: draws_of(drop(
1101- posterior :: subset_draws(draws_rvar [[var_name ]], draw = draw_iter ))))
1102- return (x )
1118+ x = .remove_leftmost_dim(posterior :: draws_of(
1119+ posterior :: subset_draws(draws_rvar [[var_name ]], draw = draw_iter )))
1120+ if (model_variables $ parameters [[var_name ]]$ dimensions == 0 ) {
1121+ return (as.double(x ))
1122+ } else {
1123+ return (x )
1124+ }
11031125 })
11041126 bad_names = unlist(lapply(variable_names , function (var_name ) {
11051127 x = drop(posterior :: draws_of(drop(
@@ -1295,13 +1317,13 @@ process_init_approx <- function(init, num_procs, model_variables = NULL,
12951317 # Calculate unique draws based on 'lw' using base R functions
12961318 unique_draws = length(unique(draws_df $ lw ))
12971319 if (num_procs > unique_draws ) {
1298- if (inherits(init , " CmdStanPathfinder " )) {
1320+ if (inherits(init , " CmdStanPathfinder" )) {
12991321 algo_name = " Pathfinder "
13001322 extra_msg = " Try running Pathfinder with psis_resample=FALSE."
13011323 } else if (inherits(init , " CmdStanVB" )) {
13021324 algo_name = " CmdStanVB "
13031325 extra_msg = " "
1304- } else if (inherits(init , " CmdStanLaplace " )) {
1326+ } else if (inherits(init , " CmdStanLaplace" )) {
13051327 algo_name = " CmdStanLaplace "
13061328 extra_msg = " "
13071329 } else {
0 commit comments