Skip to content

Commit 20c244e

Browse files
committed
Vectorise binary log_sum_exp
1 parent 38289cd commit 20c244e

2 files changed

Lines changed: 32 additions & 1 deletion

File tree

stan/math/prim/fun/log_sum_exp.hpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <stan/math/prim/fun/constants.hpp>
66
#include <stan/math/prim/fun/Eigen.hpp>
77
#include <stan/math/prim/fun/log1p_exp.hpp>
8+
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
89
#include <cmath>
910
#include <vector>
1011

@@ -47,7 +48,8 @@ namespace math {
4748
* @param a the first variable
4849
* @param b the second variable
4950
*/
50-
template <typename T1, typename T2, require_all_not_st_var<T1, T2>* = nullptr>
51+
template <typename T1, typename T2, require_all_not_st_var<T1, T2>* = nullptr,
52+
require_all_stan_scalar_t<T1, T2>* = nullptr>
5153
inline return_type_t<T1, T2> log_sum_exp(const T2& a, const T1& b) {
5254
if (a == NEGATIVE_INFTY) {
5355
return b;
@@ -91,6 +93,22 @@ inline auto log_sum_exp(const T& x) {
9193
});
9294
}
9395

96+
/**
97+
* Enables the vectorized application of the log_sum_exp function,
98+
* when the first and/or second arguments are containers.
99+
*
100+
* @tparam T1 type of first input
101+
* @tparam T2 type of second input
102+
* @param a First input
103+
* @param b Second input
104+
* @return log_sum_exp function applied to the two inputs.
105+
*/
106+
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
107+
inline auto log_sum_exp(const T1& a, const T2& b) {
108+
return apply_scalar_binary(
109+
a, b, [](const auto& c, const auto& d) { return log_sum_exp(c, d); });
110+
}
111+
94112
} // namespace math
95113
} // namespace stan
96114

test/unit/math/mix/fun/log_sum_exp_test.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,16 @@ TEST(MathMixMatFun, logSumExp) {
9696
std::vector<double>(x2c.data(), x2c.data() + x2c.size())};
9797
stan::test::expect_ad(tols, f, ststx);
9898
}
99+
100+
TEST(mathMixScalFun, logSumExp_vec) {
101+
auto f = [](const auto& x1, const auto& x2) {
102+
using stan::math::log_sum_exp;
103+
return log_sum_exp(x1, x2);
104+
};
105+
106+
Eigen::VectorXd in1(2);
107+
in1 << 3, 1;
108+
Eigen::VectorXd in2(2);
109+
in2 << 0.5, 3.4;
110+
stan::test::expect_ad_vectorized_binary(f, in1, in2);
111+
}

0 commit comments

Comments
 (0)