@@ -296,9 +296,7 @@ inline void set_zero_adjoint(Output&& output) {
296296 } else if constexpr (is_stan_scalar_v<output_i_t >) {
297297 output_i.adj () = 0 ;
298298 } else {
299- static_assert (1 ,
300- " set_zero_adjoint missed!!! This is an internal "
301- " error please report an issue on the Stan github" );
299+ static_assert (1 , " INTERNAL ERROR:(laplace_marginal_lpdf) set_zero_adjoints was not able to deduce the actiopns needed for the given type." );
302300 }
303301 },
304302 std::forward<Output>(output));
@@ -331,7 +329,7 @@ inline void collect_adjoints(Output& output, Input1&& precalc) {
331329 precalc_i.adj () = 0 ;
332330 }
333331 } else {
334- static_assert (1 , " We missed!!! " );
332+ static_assert (1 , " INTERNAL ERROR:(laplace_marginal_lpdf) collect_adjoints was not able to deduce the actiopns needed for the given type. " );
335333 }
336334 },
337335 std::forward<Output>(output), std::forward<Input1>(precalc));
@@ -769,9 +767,10 @@ template <typename Output, typename Input1>
769767inline void collect_adjoints (Output&& output, const vari* ret,
770768 Input1&& precalc) {
771769 if constexpr (is_tuple_v<Output>) {
772- static_assert (1 ,
770+ static_assert (1 ," INTERNAL ERROR:(laplace_marginal_lpdf) "
773771 " Accumulate Adjoints called on a tuple, but tuples cannot be "
774- " on the reverse mode stack!" );
772+ " on the reverse mode stack!"
773+ " This is an internal error, please report it to the stan github as an issue." );
775774 } else if constexpr (is_std_vector_v<Output>) {
776775 if constexpr (!is_var_v<value_type_t <Output>>) {
777776 const auto output_size = output.size ();
@@ -809,10 +808,7 @@ inline void collect_adjoints(Output&& output, Input1&& precalc) {
809808 } else if constexpr (is_stan_scalar_v<output_i_t >) {
810809 output_i += precalc_i;
811810 } else {
812- static_assert (1 ,
813- " collect_adjoints was given an unexpected type! This "
814- " is an internal bug. Please file an issue on Stan's "
815- " github repository." );
811+ static_assert (1 , " INTERNAL ERROR:(laplace_marginal_lpdf) collect_adjoints was not able to deduce the actiopns needed for the given type." );
816812 }
817813 },
818814 std::forward<Output>(output), std::forward<Input1>(precalc));
@@ -838,17 +834,12 @@ inline void copy_compute_s2(Output&& output, Input1&& precalc) {
838834 } else if constexpr (is_stan_scalar_v<output_i_t >) {
839835 output_i += (0.5 * precalc_i.adj ());
840836 } else {
841- static_assert (1 , " We missed!!! " );
837+ static_assert (1 , " INTERNAL ERROR:(laplace_marginal_lpdf) copy_compute_s2 was not able to deduce the actiopns needed for the given type. " );
842838 }
843839 },
844840 std::forward<Output>(output), std::forward<Input1>(precalc));
845841}
846842
847- template <typename T>
848- static constexpr bool is_dbl_nothrow_constructible_v
849- = std::is_nothrow_constructible<
850- promote_scalar_t <double , std::decay_t <T>>>::value;
851-
852843template <typename Output>
853844inline constexpr auto make_zero (Output&& output) {
854845 if constexpr (is_tuple_v<Output>) {
@@ -857,7 +848,7 @@ inline constexpr auto make_zero(Output&& output) {
857848 } else if constexpr (is_std_vector_v<Output>) {
858849 if constexpr (!is_var_v<value_type_t <Output>>) {
859850 const auto output_size = output.size ();
860- arena_t <promote_scalar_t < double , Output >> ret;
851+ arena_t <std::vector< decltype ( make_zero (output[ 0 ])) >> ret;
861852 ret.reserve (output_size);
862853 for (Eigen::Index i = 0 ; i < output_size; ++i) {
863854 ret.push_back (make_zero (output[i]));
@@ -897,10 +888,11 @@ inline void print_adjoint(Output&& output) {
897888 } else if constexpr (is_stan_scalar_v<Output>) {
898889 std::cout << " adj: " << output.adj () << std::endl;
899890 } else {
900- static_assert (1 , " print missed!!! " );
891+ static_assert (1 , " INTERNAL ERROR:(laplace_marginal_lpdf) print_adjoint was not able to deduce the actiopns needed for the given type. " );
901892 }
902893}
903894
895+
904896template <typename Arg, typename Precalc>
905897inline void laplace_tuple_collect_adjoints (var ret, Arg&& arg,
906898 Precalc&& precalc) {
@@ -912,6 +904,10 @@ inline void laplace_tuple_collect_adjoints(var ret, Arg&& arg,
912904 std::forward<decltype (inner_precalc)>(inner_precalc));
913905 },
914906 std::forward<Arg>(arg), std::forward<Precalc>(precalc));
907+ } else if constexpr (is_std_vector_containing_tuple_v<Arg>) {
908+ for (std::size_t i = 0 ; i < arg.size (); ++i) {
909+ laplace_tuple_collect_adjoints (ret, arg[i], precalc[i]);
910+ }
915911 } else {
916912 reverse_pass_callback (
917913 [vi = ret.vi_ , arg_arena = to_arena (std::forward<Arg>(arg)),
@@ -1108,13 +1104,7 @@ inline auto laplace_marginal_density(const LLFun& ll_fun, LLTupleArgs&& ll_args,
11081104 return std::forward<decltype (arg)>(arg);
11091105 },
11101106 covar_args_copy);
1111- // std::cout << "\ncovar args: " << std::endl;
1112- // print_adjoint(covar_args_filter);
11131107 collect_adjoints (covar_args_adj, covar_args_filter);
1114- // std::cout << "\n______________\n";
1115- // std::cout << "covar args adj: " << std::endl;
1116- // print_adjoint(covar_args_adj);
1117- // std::cout << "\n==============\n";
11181108 }();
11191109 }
11201110 if constexpr (ll_args_contain_var) {
@@ -1133,13 +1123,13 @@ inline auto laplace_marginal_density(const LLFun& ll_fun, LLTupleArgs&& ll_args,
11331123 }
11341124 var ret (lmd);
11351125 if constexpr (is_any_var_scalar_v<CovarTupleArgs>) {
1136- auto covar_args_arena = stan::math::filter_map<is_any_var_scalar>(
1137- [](auto && arg) { return to_arena ( arg) ; }, covar_args_refs);
1138- laplace_tuple_collect_adjoints (ret, covar_args_arena , covar_args_adj);
1126+ auto covar_args_filter = stan::math::filter_map<is_any_var_scalar>(
1127+ [](auto && arg) -> decltype ( auto ) { return arg; }, covar_args_refs);
1128+ laplace_tuple_collect_adjoints (ret, covar_args_filter , covar_args_adj);
11391129 }
11401130 if constexpr (ll_args_contain_var) {
11411131 auto ll_args_filter = stan::math::filter_map<is_any_var_scalar>(
1142- [](auto && arg) { return to_arena ( arg) ; }, ll_args_refs);
1132+ [](auto && arg) -> decltype ( auto ) { return arg; }, ll_args_refs);
11431133 laplace_tuple_collect_adjoints (ret, ll_args_filter, partial_parm);
11441134 }
11451135 return ret;
0 commit comments