Skip to content

Commit 7969803

Browse files
committed
update filter_map to handle arrays of tuples
1 parent e0145b6 commit 7969803

14 files changed

Lines changed: 358 additions & 180 deletions

stan/math/mix/functor/laplace_likelihood.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#ifndef STAN_MATH_MIX_FUNCTOR_LAPLACE_LIKELIHOOD_HPP
22
#define STAN_MATH_MIX_FUNCTOR_LAPLACE_LIKELIHOOD_HPP
33

4-
// #include <stan/math/mix/laplace/hessian_times_vector.hpp>
54
#include <stan/math/mix/functor/hessian_block_diag.hpp>
65
#include <stan/math/prim/functor.hpp>
76
#include <stan/math/prim/fun.hpp>
@@ -61,6 +60,14 @@ inline auto conditional_copy_and_promote(Args&&... args) {
6160
std::forward<decltype(inner_args)>(inner_args))...);
6261
},
6362
std::forward<decltype(arg)>(arg));
63+
} else if constexpr (is_std_vector_v<decltype(arg)>) {
64+
std::vector<decltype(conditional_copy_and_promote<Filter, PromotedType,
65+
CopyType>(arg[0]))> ret;
66+
for (std::size_t i = 0; i < arg.size(); ++i) {
67+
ret.push_back(conditional_copy_and_promote<Filter, PromotedType,
68+
CopyType>(arg[i]));
69+
}
70+
return ret;
6471
} else {
6572
if constexpr (CopyType == COPY_TYPE::DEEP) {
6673
return stan::math::eval(promote_scalar<PromotedType>(

stan/math/mix/functor/laplace_marginal_density.hpp

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,7 @@ inline void set_zero_adjoint(Output&& output) {
296296
} else if constexpr (is_stan_scalar_v<output_i_t>) {
297297
output_i.adj() = 0;
298298
} else {
299-
static_assert(1,
300-
"set_zero_adjoint missed!!! This is an internal "
301-
"error please report an issue on the Stan github");
299+
static_assert(1, "INTERNAL ERROR:(laplace_marginal_lpdf) set_zero_adjoints was not able to deduce the actiopns needed for the given type.");
302300
}
303301
},
304302
std::forward<Output>(output));
@@ -331,7 +329,7 @@ inline void collect_adjoints(Output& output, Input1&& precalc) {
331329
precalc_i.adj() = 0;
332330
}
333331
} else {
334-
static_assert(1, "We missed!!!");
332+
static_assert(1, "INTERNAL ERROR:(laplace_marginal_lpdf) collect_adjoints was not able to deduce the actiopns needed for the given type.");
335333
}
336334
},
337335
std::forward<Output>(output), std::forward<Input1>(precalc));
@@ -769,9 +767,10 @@ template <typename Output, typename Input1>
769767
inline void collect_adjoints(Output&& output, const vari* ret,
770768
Input1&& precalc) {
771769
if constexpr (is_tuple_v<Output>) {
772-
static_assert(1,
770+
static_assert(1,"INTERNAL ERROR:(laplace_marginal_lpdf)"
773771
"Accumulate Adjoints called on a tuple, but tuples cannot be "
774-
"on the reverse mode stack!");
772+
"on the reverse mode stack!"
773+
"This is an internal error, please report it to the stan github as an issue.");
775774
} else if constexpr (is_std_vector_v<Output>) {
776775
if constexpr (!is_var_v<value_type_t<Output>>) {
777776
const auto output_size = output.size();
@@ -809,10 +808,7 @@ inline void collect_adjoints(Output&& output, Input1&& precalc) {
809808
} else if constexpr (is_stan_scalar_v<output_i_t>) {
810809
output_i += precalc_i;
811810
} else {
812-
static_assert(1,
813-
"collect_adjoints was given an unexpected type! This "
814-
"is an internal bug. Please file an issue on Stan's "
815-
"github repository.");
811+
static_assert(1, "INTERNAL ERROR:(laplace_marginal_lpdf) collect_adjoints was not able to deduce the actiopns needed for the given type.");
816812
}
817813
},
818814
std::forward<Output>(output), std::forward<Input1>(precalc));
@@ -838,17 +834,12 @@ inline void copy_compute_s2(Output&& output, Input1&& precalc) {
838834
} else if constexpr (is_stan_scalar_v<output_i_t>) {
839835
output_i += (0.5 * precalc_i.adj());
840836
} else {
841-
static_assert(1, "We missed!!!");
837+
static_assert(1, "INTERNAL ERROR:(laplace_marginal_lpdf) copy_compute_s2 was not able to deduce the actiopns needed for the given type.");
842838
}
843839
},
844840
std::forward<Output>(output), std::forward<Input1>(precalc));
845841
}
846842

847-
template <typename T>
848-
static constexpr bool is_dbl_nothrow_constructible_v
849-
= std::is_nothrow_constructible<
850-
promote_scalar_t<double, std::decay_t<T>>>::value;
851-
852843
template <typename Output>
853844
inline constexpr auto make_zero(Output&& output) {
854845
if constexpr (is_tuple_v<Output>) {
@@ -857,7 +848,7 @@ inline constexpr auto make_zero(Output&& output) {
857848
} else if constexpr (is_std_vector_v<Output>) {
858849
if constexpr (!is_var_v<value_type_t<Output>>) {
859850
const auto output_size = output.size();
860-
arena_t<promote_scalar_t<double, Output>> ret;
851+
arena_t<std::vector<decltype(make_zero(output[0]))>> ret;
861852
ret.reserve(output_size);
862853
for (Eigen::Index i = 0; i < output_size; ++i) {
863854
ret.push_back(make_zero(output[i]));
@@ -897,10 +888,11 @@ inline void print_adjoint(Output&& output) {
897888
} else if constexpr (is_stan_scalar_v<Output>) {
898889
std::cout << "adj: " << output.adj() << std::endl;
899890
} else {
900-
static_assert(1, "print missed!!!");
891+
static_assert(1, "INTERNAL ERROR:(laplace_marginal_lpdf) print_adjoint was not able to deduce the actiopns needed for the given type.");
901892
}
902893
}
903894

895+
904896
template <typename Arg, typename Precalc>
905897
inline void laplace_tuple_collect_adjoints(var ret, Arg&& arg,
906898
Precalc&& precalc) {
@@ -912,6 +904,10 @@ inline void laplace_tuple_collect_adjoints(var ret, Arg&& arg,
912904
std::forward<decltype(inner_precalc)>(inner_precalc));
913905
},
914906
std::forward<Arg>(arg), std::forward<Precalc>(precalc));
907+
} else if constexpr (is_std_vector_containing_tuple_v<Arg>) {
908+
for (std::size_t i = 0; i < arg.size(); ++i) {
909+
laplace_tuple_collect_adjoints(ret, arg[i], precalc[i]);
910+
}
915911
} else {
916912
reverse_pass_callback(
917913
[vi = ret.vi_, arg_arena = to_arena(std::forward<Arg>(arg)),
@@ -1108,13 +1104,7 @@ inline auto laplace_marginal_density(const LLFun& ll_fun, LLTupleArgs&& ll_args,
11081104
return std::forward<decltype(arg)>(arg);
11091105
},
11101106
covar_args_copy);
1111-
// std::cout << "\ncovar args: " << std::endl;
1112-
// print_adjoint(covar_args_filter);
11131107
collect_adjoints(covar_args_adj, covar_args_filter);
1114-
// std::cout << "\n______________\n";
1115-
// std::cout << "covar args adj: " << std::endl;
1116-
// print_adjoint(covar_args_adj);
1117-
// std::cout << "\n==============\n";
11181108
}();
11191109
}
11201110
if constexpr (ll_args_contain_var) {
@@ -1133,13 +1123,13 @@ inline auto laplace_marginal_density(const LLFun& ll_fun, LLTupleArgs&& ll_args,
11331123
}
11341124
var ret(lmd);
11351125
if constexpr (is_any_var_scalar_v<CovarTupleArgs>) {
1136-
auto covar_args_arena = stan::math::filter_map<is_any_var_scalar>(
1137-
[](auto&& arg) { return to_arena(arg); }, covar_args_refs);
1138-
laplace_tuple_collect_adjoints(ret, covar_args_arena, covar_args_adj);
1126+
auto covar_args_filter = stan::math::filter_map<is_any_var_scalar>(
1127+
[](auto&& arg) -> decltype(auto) { return arg; }, covar_args_refs);
1128+
laplace_tuple_collect_adjoints(ret, covar_args_filter, covar_args_adj);
11391129
}
11401130
if constexpr (ll_args_contain_var) {
11411131
auto ll_args_filter = stan::math::filter_map<is_any_var_scalar>(
1142-
[](auto&& arg) { return to_arena(arg); }, ll_args_refs);
1132+
[](auto&& arg) -> decltype(auto) { return arg; }, ll_args_refs);
11431133
laplace_tuple_collect_adjoints(ret, ll_args_filter, partial_parm);
11441134
}
11451135
return ret;

stan/math/prim/fun/promote_scalar.hpp

Lines changed: 25 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -10,103 +10,39 @@
1010
namespace stan {
1111
namespace math {
1212

13-
/**
14-
* Promote a scalar to another scalar type
15-
*
16-
* @tparam PromotionScalar scalar type of output.
17-
* @tparam UnPromotedType input type. `UnPromotedType` must be constructible
18-
* from `PromotionScalar`
19-
* @param x input scalar to be promoted to `PromotionScalar` type
20-
*/
21-
template <typename PromotionScalar, typename UnPromotedType,
22-
require_constructible_t<PromotionScalar, UnPromotedType>* = nullptr,
23-
require_not_same_t<PromotionScalar, UnPromotedType>* = nullptr,
24-
require_all_not_tuple_t<PromotionScalar, UnPromotedType>* = nullptr>
25-
inline constexpr auto promote_scalar(UnPromotedType&& x) {
26-
return PromotionScalar(std::forward<UnPromotedType>(x));
27-
}
28-
29-
/**
30-
* No-op overload when promoting a type's scalar to the type it already has.
31-
*
32-
* @tparam PromotionScalar scalar type of output.
33-
* @tparam UnPromotedType input type. `UnPromotedType`'s `scalar_type` must be
34-
* equal to `PromotionScalar`
35-
* @param x input
36-
*/
37-
template <
38-
typename PromotionScalar, typename UnPromotedType,
39-
require_same_t<PromotionScalar, scalar_type_t<UnPromotedType>>* = nullptr>
40-
inline constexpr auto promote_scalar(UnPromotedType&& x) noexcept {
41-
return std::forward<UnPromotedType>(x);
42-
}
43-
44-
/**
45-
* Promote the scalar type of an eigen matrix to the requested type.
46-
*
47-
* @tparam PromotionScalar scalar type of output.
48-
* @tparam UnPromotedType input type. The `PromotionScalar` type must be
49-
* constructible from `UnPromotedType`'s `scalar_type`
50-
* @param x input
51-
*/
52-
template <typename PromotionScalar, typename UnPromotedType,
53-
require_eigen_t<UnPromotedType>* = nullptr,
54-
require_not_same_t<PromotionScalar,
55-
value_type_t<UnPromotedType>>* = nullptr>
56-
inline auto promote_scalar(UnPromotedType&& x) {
57-
return x.template cast<PromotionScalar>();
58-
}
59-
60-
// Forward decl for iterating over tuples used in std::vector<tuple>
61-
template <typename PromotionScalars, typename UnPromotedTypes,
62-
require_all_tuple_t<PromotionScalars, UnPromotedTypes>* = nullptr,
63-
require_not_same_t<PromotionScalars, UnPromotedTypes>* = nullptr>
64-
inline constexpr promote_scalar_t<PromotionScalars, UnPromotedTypes>
65-
promote_scalar(UnPromotedTypes&& x);
6613

67-
/**
68-
* Promote the scalar type of an standard vector to the requested type.
69-
*
70-
* @tparam PromotionScalar scalar type of output.
71-
* @tparam UnPromotedType input type. The `PromotionScalar` type must be
72-
* constructible from `UnPromotedType`'s `scalar_type`
73-
* @param x input
74-
*/
75-
template <typename PromotionScalar, typename UnPromotedType,
76-
require_std_vector_t<UnPromotedType>* = nullptr,
77-
require_not_same_t<PromotionScalar,
78-
scalar_type_t<UnPromotedType>>* = nullptr>
79-
inline auto promote_scalar(UnPromotedType&& x) {
80-
const auto x_size = x.size();
81-
promote_scalar_t<PromotionScalar, UnPromotedType> ret(x_size);
82-
for (size_t i = 0; i < x_size; ++i) {
83-
ret[i] = promote_scalar<PromotionScalar>(x[i]);
84-
}
85-
return ret;
86-
}
87-
88-
/**
89-
* Promote the scalar type of a tuples elements to the requested types.
90-
*
91-
* @tparam PromotionScalars A tuple of scalar types that is the same size as the
92-
* tuple of `UnPromotedTypes`.
93-
* @tparam UnPromotedTypes tuple input. Each `PromotionScalars` element must be
94-
* constructible from it's associated element of `UnPromotedTypes` `scalar_type`
95-
* @param x input
96-
*/
97-
template <typename PromotionScalars, typename UnPromotedTypes,
98-
require_all_tuple_t<PromotionScalars, UnPromotedTypes>*,
99-
require_not_same_t<PromotionScalars, UnPromotedTypes>*>
100-
inline constexpr promote_scalar_t<PromotionScalars, UnPromotedTypes>
101-
promote_scalar(UnPromotedTypes&& x) {
102-
return index_apply<std::tuple_size<std::decay_t<UnPromotedTypes>>::value>(
14+
template <typename PromotionScalars, typename UnPromotedTypes>
15+
inline constexpr auto promote_scalar(UnPromotedTypes&& x) {
16+
if constexpr (std::is_same_v<PromotionScalars, scalar_type_t<UnPromotedTypes>>) {
17+
return std::forward<UnPromotedTypes>(x);
18+
} else if constexpr (is_tuple_v<PromotionScalars> && is_tuple_v<UnPromotedTypes>) {
19+
return index_apply<std::tuple_size<std::decay_t<UnPromotedTypes>>::value>(
10320
[&x](auto... Is) {
10421
return std::make_tuple(
10522
promote_scalar<std::decay_t<decltype(std::get<Is>(
10623
std::declval<PromotionScalars>()))>>(std::get<Is>(x))...);
10724
});
25+
} else if constexpr (is_tuple_v<UnPromotedTypes>) {
26+
return stan::math::apply([](auto&&... args) {
27+
return std::make_tuple(promote_scalar<PromotionScalars>(std::forward<decltype(args)>(args))...);
28+
}, std::forward<UnPromotedTypes>(x));
29+
} else if constexpr (is_std_vector_v<UnPromotedTypes>) {
30+
const auto x_size = x.size();
31+
promote_scalar_t<PromotionScalars, UnPromotedTypes> ret(x_size);
32+
for (size_t i = 0; i < x_size; ++i) {
33+
ret[i] = promote_scalar<PromotionScalars>(x[i]);
34+
}
35+
return ret;
36+
} else if constexpr (is_eigen_v<UnPromotedTypes>) {
37+
return std::forward<UnPromotedTypes>(x).template cast<PromotionScalars>();
38+
} else if constexpr (is_stan_scalar_v<UnPromotedTypes>) {
39+
return PromotionScalars(std::forward<UnPromotedTypes>(x));
40+
} else {
41+
static_assert(1, "Missed type in promote_scalar!");
42+
}
10843
}
10944

45+
11046
} // namespace math
11147
} // namespace stan
11248

0 commit comments

Comments
 (0)