|
5 | 5 | #include <stan/math/prim/fun/constants.hpp> |
6 | 6 | #include <stan/math/prim/fun/Eigen.hpp> |
7 | 7 | #include <stan/math/prim/fun/log1p_exp.hpp> |
| 8 | +#include <stan/math/prim/functor/apply_scalar_binary.hpp> |
8 | 9 | #include <cmath> |
9 | 10 | #include <vector> |
10 | 11 |
|
@@ -47,7 +48,8 @@ namespace math { |
47 | 48 | * @param a the first variable |
48 | 49 | * @param b the second variable |
49 | 50 | */ |
50 | | -template <typename T1, typename T2, require_all_not_st_var<T1, T2>* = nullptr> |
| 51 | +template <typename T1, typename T2, require_all_not_st_var<T1, T2>* = nullptr, |
| 52 | + require_all_stan_scalar_t<T1, T2>* = nullptr> |
51 | 53 | inline return_type_t<T1, T2> log_sum_exp(const T2& a, const T1& b) { |
52 | 54 | if (a == NEGATIVE_INFTY) { |
53 | 55 | return b; |
@@ -91,6 +93,22 @@ inline auto log_sum_exp(const T& x) { |
91 | 93 | }); |
92 | 94 | } |
93 | 95 |
|
| 96 | +/** |
| 97 | + * Enables the vectorized application of the log_sum_exp function, |
| 98 | + * when the first and/or second arguments are containers. |
| 99 | + * |
| 100 | + * @tparam T1 type of first input |
| 101 | + * @tparam T2 type of second input |
| 102 | + * @param a First input |
| 103 | + * @param b Second input |
| 104 | + * @return log_sum_exp function applied to the two inputs. |
| 105 | + */ |
| 106 | +template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr> |
| 107 | +inline auto log_sum_exp(const T1& a, const T2& b) { |
| 108 | + return apply_scalar_binary( |
| 109 | + a, b, [](const auto& c, const auto& d) { return log_sum_exp(c, d); }); |
| 110 | +} |
| 111 | + |
94 | 112 | } // namespace math |
95 | 113 | } // namespace stan |
96 | 114 |
|
|
0 commit comments