Skip to content

Commit c7cb0ea

Browse files
authored
Merge pull request #2950 from stan-dev/compound-funs
Cleanup more usage of compound functions throughout Math
2 parents efbc688 + cc8dc55 commit c7cb0ea

24 files changed

Lines changed: 62 additions & 49 deletions

stan/math/fwd/fun/log1m_inv_logit.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace math {
2020
template <typename T>
2121
inline fvar<T> log1m_inv_logit(const fvar<T>& x) {
2222
using std::exp;
23-
return fvar<T>(log1m_inv_logit(x.val_), -x.d_ / (1 + exp(-x.val_)));
23+
return fvar<T>(log1m_inv_logit(x.val_), -x.d_ * inv_logit(x.val_));
2424
}
2525

2626
} // namespace math

stan/math/fwd/fun/log1p_exp.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace math {
1212
template <typename T>
1313
inline fvar<T> log1p_exp(const fvar<T>& x) {
1414
using std::exp;
15-
return fvar<T>(log1p_exp(x.val_), x.d_ / (1 + exp(-x.val_)));
15+
return fvar<T>(log1p_exp(x.val_), x.d_ * inv_logit(x.val_));
1616
}
1717

1818
} // namespace math

stan/math/fwd/fun/log_inv_logit.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <stan/math/fwd/meta.hpp>
55
#include <stan/math/fwd/core.hpp>
6+
#include <stan/math/prim/fun/inv_logit.hpp>
67
#include <stan/math/prim/fun/log_inv_logit.hpp>
78
#include <cmath>
89

@@ -12,7 +13,7 @@ namespace math {
1213
template <typename T>
1314
inline fvar<T> log_inv_logit(const fvar<T>& x) {
1415
using std::exp;
15-
return fvar<T>(log_inv_logit(x.val_), x.d_ / (1 + exp(x.val_)));
16+
return fvar<T>(log_inv_logit(x.val_), x.d_ * inv_logit(-x.val_));
1617
}
1718
} // namespace math
1819
} // namespace stan

stan/math/fwd/fun/log_sum_exp.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ template <typename T>
1818
inline fvar<T> log_sum_exp(const fvar<T>& x1, const fvar<T>& x2) {
1919
using std::exp;
2020
return fvar<T>(log_sum_exp(x1.val_, x2.val_),
21-
x1.d_ / (1 + exp(x2.val_ - x1.val_))
22-
+ x2.d_ / (exp(x1.val_ - x2.val_) + 1));
21+
x1.d_ * inv_logit(-(x2.val_ - x1.val_))
22+
+ x2.d_ * inv_logit(-(x1.val_ - x2.val_)));
2323
}
2424

2525
template <typename T>
@@ -28,7 +28,7 @@ inline fvar<T> log_sum_exp(double x1, const fvar<T>& x2) {
2828
if (x1 == NEGATIVE_INFTY) {
2929
return fvar<T>(x2.val_, x2.d_);
3030
}
31-
return fvar<T>(log_sum_exp(x1, x2.val_), x2.d_ / (exp(x1 - x2.val_) + 1));
31+
return fvar<T>(log_sum_exp(x1, x2.val_), x2.d_ * inv_logit(-(x1 - x2.val_)));
3232
}
3333

3434
template <typename T>

stan/math/opencl/kernel_generator/elt_function_cl.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ ADD_UNARY_FUNCTION_WITH_INCLUDES(digamma,
298298
opencl_kernels::digamma_device_function)
299299
ADD_UNARY_FUNCTION_WITH_INCLUDES(log1m, opencl_kernels::log1m_device_function)
300300
ADD_UNARY_FUNCTION_WITH_INCLUDES(log_inv_logit,
301+
opencl_kernels::log1p_exp_device_function,
301302
opencl_kernels::log_inv_logit_device_function)
302303
ADD_UNARY_FUNCTION_WITH_INCLUDES(log1m_exp,
303304
opencl_kernels::log1m_exp_device_function)
@@ -317,7 +318,8 @@ ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_Phi, opencl_kernels::log1m_device_function,
317318
opencl_kernels::phi_device_function,
318319
opencl_kernels::inv_phi_device_function)
319320
ADD_UNARY_FUNCTION_WITH_INCLUDES(
320-
log1m_inv_logit, opencl_kernels::log1m_inv_logit_device_function)
321+
log1m_inv_logit, opencl_kernels::log1p_exp_device_function,
322+
opencl_kernels::log1m_inv_logit_device_function)
321323
ADD_UNARY_FUNCTION_WITH_INCLUDES(trigamma,
322324
opencl_kernels::trigamma_device_function)
323325
ADD_UNARY_FUNCTION_WITH_INCLUDES(

stan/math/opencl/kernels/device_functions/log1m_inv_logit.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ static const char* log1m_inv_logit_device_function
4040
*/
4141
inline double log1m_inv_logit(double x) {
4242
if (x > 0.0) {
43-
return -x - log1p(exp(-x)); // prevent underflow
43+
return -x - log1p_exp(-x); // prevent underflow
4444
}
45-
return -log1p(exp(x));
45+
return -log1p_exp(x);
4646
}
4747
// \cond
4848
) "\n#endif\n"; // NOLINT

stan/math/opencl/kernels/device_functions/log_inv_logit.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ static const char* log_inv_logit_device_function
2727
*/
2828
double log_inv_logit(double x) {
2929
if (x < 0.0) {
30-
return x - log1p(exp(x)); // prevent underflow
30+
return x - log1p_exp(x); // prevent underflow
3131
}
32-
return -log1p(exp(-x));
32+
return -log1p_exp(-x);
3333
}
3434
// \cond
3535
) "\n#endif\n"; // NOLINT

stan/math/opencl/prim/binomial_logit_lpmf.hpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
#include <stan/math/prim/err.hpp>
88
#include <stan/math/prim/functor/partials_propagator.hpp>
99
#include <stan/math/prim/fun/binomial_coefficient_log.hpp>
10-
#include <stan/math/prim/fun/inv_logit.hpp>
10+
#include <stan/math/prim/fun/log_inv_logit.hpp>
11+
#include <stan/math/prim/fun/log1m_inv_logit.hpp>
1112

1213
namespace stan {
1314
namespace math {
@@ -60,18 +61,15 @@ return_type_t<T_prob_cl> binomial_logit_lpmf(const T_n_cl& n, const T_N_cl N,
6061
= check_cl(function, "Probability parameter", alpha_val, "finite");
6162
auto alpha_finite = isfinite(alpha_val);
6263

63-
auto inv_logit_alpha = inv_logit(alpha_val);
64-
auto inv_logit_neg_alpha = inv_logit(-alpha_val);
65-
auto log_inv_logit_alpha = log(inv_logit_alpha);
66-
auto log_inv_logit_neg_alpha = log(inv_logit_neg_alpha);
64+
auto log_inv_logit_alpha = log_inv_logit(alpha_val);
65+
auto log1m_inv_logit_alpha = log1m_inv_logit(alpha_val);
6766
auto n_diff = N - n;
6867
auto logp_expr1 = elt_multiply(n, log_inv_logit_alpha)
69-
+ elt_multiply(n_diff, log_inv_logit_neg_alpha);
68+
+ elt_multiply(n_diff, log1m_inv_logit_alpha);
7069
auto logp_expr
7170
= static_select<include_summand<propto, T_n_cl, T_N_cl>::value>(
7271
logp_expr1 + binomial_coefficient_log(N, n), logp_expr1);
73-
auto alpha_deriv = elt_multiply(n, inv_logit_neg_alpha)
74-
- elt_multiply(n_diff, inv_logit_alpha);
72+
auto alpha_deriv = n - elt_multiply(N, exp(log_inv_logit_alpha));
7573

7674
matrix_cl<double> logp_cl;
7775
matrix_cl<double> alpha_deriv_cl;

stan/math/opencl/prim/logistic_lpdf.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <stan/math/prim/meta.hpp>
88
#include <stan/math/prim/err.hpp>
99
#include <stan/math/prim/fun/digamma.hpp>
10+
#include <stan/math/prim/fun/log1p_exp.hpp>
1011
#include <stan/math/prim/fun/lgamma.hpp>
1112
#include <stan/math/prim/fun/max_size.hpp>
1213
#include <stan/math/prim/functor/partials_propagator.hpp>
@@ -75,7 +76,7 @@ return_type_t<T_y_cl, T_loc_cl, T_scale_cl> logistic_lpdf(
7576
auto y_minus_mu = y_val - mu_val;
7677
auto y_minus_mu_div_sigma = elt_multiply(y_minus_mu, inv_sigma);
7778

78-
auto logp1 = -y_minus_mu_div_sigma - 2.0 * log1p(exp(-y_minus_mu_div_sigma));
79+
auto logp1 = -y_minus_mu_div_sigma - 2.0 * log1p_exp(-y_minus_mu_div_sigma);
7980
auto logp_expr
8081
= colwise_sum(static_select<include_summand<propto, T_scale_cl>::value>(
8182
logp1 - log(sigma_val), logp1));

stan/math/opencl/prim/pareto_lcdf.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ return_type_t<T_y_cl, T_scale_cl, T_shape_cl> pareto_lcdf(
6868

6969
auto log_quot = log(elt_divide(y_min_val, y_val));
7070
auto exp_prod = exp(elt_multiply(alpha_val, log_quot));
71-
auto lcdf_expr = colwise_sum(log(1.0 - exp_prod));
71+
// TODO(Andrew) Further simplify derivatives and log1m_exp below
72+
auto lcdf_expr = colwise_sum(log1m(exp_prod));
7273

7374
auto common_deriv = elt_divide(exp_prod, 1.0 - exp_prod);
7475

0 commit comments

Comments
 (0)