Skip to content

Commit 9789fb8

Browse files
committed
Update log1p_exp usage
1 parent eb3b5d7 commit 9789fb8

8 files changed

Lines changed: 21 additions & 16 deletions

File tree

stan/math/opencl/kernel_generator/elt_function_cl.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ ADD_UNARY_FUNCTION_WITH_INCLUDES(digamma,
298298
opencl_kernels::digamma_device_function)
299299
ADD_UNARY_FUNCTION_WITH_INCLUDES(log1m, opencl_kernels::log1m_device_function)
300300
ADD_UNARY_FUNCTION_WITH_INCLUDES(log_inv_logit,
301+
opencl_kernels::log1p_exp_device_function,
301302
opencl_kernels::log_inv_logit_device_function)
302303
ADD_UNARY_FUNCTION_WITH_INCLUDES(log1m_exp,
303304
opencl_kernels::log1m_exp_device_function)
@@ -316,8 +317,9 @@ ADD_UNARY_FUNCTION_WITH_INCLUDES(Phi_approx,
316317
ADD_UNARY_FUNCTION_WITH_INCLUDES(inv_Phi, opencl_kernels::log1m_device_function,
317318
opencl_kernels::phi_device_function,
318319
opencl_kernels::inv_phi_device_function)
319-
ADD_UNARY_FUNCTION_WITH_INCLUDES(
320-
log1m_inv_logit, opencl_kernels::log1m_inv_logit_device_function)
320+
ADD_UNARY_FUNCTION_WITH_INCLUDES(log1m_inv_logit,
321+
opencl_kernels::log1p_exp_device_function,
322+
opencl_kernels::log1m_inv_logit_device_function)
321323
ADD_UNARY_FUNCTION_WITH_INCLUDES(trigamma,
322324
opencl_kernels::trigamma_device_function)
323325
ADD_UNARY_FUNCTION_WITH_INCLUDES(

stan/math/opencl/kernels/device_functions/log1m_inv_logit.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ static const char* log1m_inv_logit_device_function
4040
*/
4141
inline double log1m_inv_logit(double x) {
4242
if (x > 0.0) {
43-
return -x - log1p(exp(-x)); // prevent underflow
43+
return -x - log1p_exp(-x); // prevent underflow
4444
}
45-
return -log1p(exp(x));
45+
return -log1p_exp(x);
4646
}
4747
// \cond
4848
) "\n#endif\n"; // NOLINT

stan/math/opencl/kernels/device_functions/log_inv_logit.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ static const char* log_inv_logit_device_function
2727
*/
2828
double log_inv_logit(double x) {
2929
if (x < 0.0) {
30-
return x - log1p(exp(x)); // prevent underflow
30+
return x - log1p_exp(x); // prevent underflow
3131
}
32-
return -log1p(exp(-x));
32+
return -log1p_exp(-x);
3333
}
3434
// \cond
3535
) "\n#endif\n"; // NOLINT

stan/math/opencl/prim/logistic_lpdf.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <stan/math/prim/meta.hpp>
88
#include <stan/math/prim/err.hpp>
99
#include <stan/math/prim/fun/digamma.hpp>
10+
#include <stan/math/prim/fun/log1p_exp.hpp>
1011
#include <stan/math/prim/fun/lgamma.hpp>
1112
#include <stan/math/prim/fun/max_size.hpp>
1213
#include <stan/math/prim/functor/partials_propagator.hpp>
@@ -75,7 +76,7 @@ return_type_t<T_y_cl, T_loc_cl, T_scale_cl> logistic_lpdf(
7576
auto y_minus_mu = y_val - mu_val;
7677
auto y_minus_mu_div_sigma = elt_multiply(y_minus_mu, inv_sigma);
7778

78-
auto logp1 = -y_minus_mu_div_sigma - 2.0 * log1p(exp(-y_minus_mu_div_sigma));
79+
auto logp1 = -y_minus_mu_div_sigma - 2.0 * log1p_exp(-y_minus_mu_div_sigma);
7980
auto logp_expr
8081
= colwise_sum(static_select<include_summand<propto, T_scale_cl>::value>(
8182
logp1 - log(sigma_val), logp1));

stan/math/prim/fun/log1m_inv_logit.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/fun/exp.hpp>
6-
#include <stan/math/prim/fun/log1p.hpp>
6+
#include <stan/math/prim/fun/log1p_exp.hpp>
77
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
88
#include <cmath>
99

@@ -36,9 +36,9 @@ namespace math {
3636
inline double log1m_inv_logit(double u) {
3737
using std::exp;
3838
if (u > 0.0) {
39-
return -u - log1p(exp(-u)); // prevent underflow
39+
return -u - log1p_exp(-u); // prevent underflow
4040
}
41-
return -log1p(exp(u));
41+
return -log1p_exp(u);
4242
}
4343

4444
/**

stan/math/prim/fun/log_inv_logit.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/fun/exp.hpp>
6-
#include <stan/math/prim/fun/log1p.hpp>
6+
#include <stan/math/prim/fun/log1p_exp.hpp>
77
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
88
#include <cmath>
99

@@ -34,9 +34,9 @@ namespace math {
3434
inline double log_inv_logit(double u) {
3535
using std::exp;
3636
if (u < 0.0) {
37-
return u - log1p(exp(u)); // prevent underflow
37+
return u - log1p_exp(u); // prevent underflow
3838
}
39-
return -log1p(exp(-u));
39+
return -log1p_exp(-u);
4040
}
4141

4242
/**

stan/math/prim/prob/logistic_lpdf.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <stan/math/prim/fun/exp.hpp>
1010
#include <stan/math/prim/fun/log.hpp>
1111
#include <stan/math/prim/fun/log1p.hpp>
12+
#include <stan/math/prim/fun/log1p_exp.hpp>
1213
#include <stan/math/prim/fun/max_size.hpp>
1314
#include <stan/math/prim/fun/size.hpp>
1415
#include <stan/math/prim/fun/size_zero.hpp>
@@ -63,7 +64,7 @@ return_type_t<T_y, T_loc, T_scale> logistic_lpdf(const T_y& y, const T_loc& mu,
6364

6465
size_t N = max_size(y, mu, sigma);
6566
T_partials_return logp = -sum(y_minus_mu_div_sigma)
66-
- 2.0 * sum(log1p(exp(-y_minus_mu_div_sigma)));
67+
- 2.0 * sum(log1p_exp(-y_minus_mu_div_sigma));
6768
if (include_summand<propto, T_scale>::value) {
6869
logp -= sum(log(sigma_val)) * N / math::size(sigma);
6970
}

stan/math/prim/prob/neg_binomial_2_log_glm_lpmf.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <stan/math/prim/fun/exp.hpp>
1111
#include <stan/math/prim/fun/lgamma.hpp>
1212
#include <stan/math/prim/fun/log.hpp>
13+
#include <stan/math/prim/fun/log1p_exp.hpp>
1314
#include <stan/math/prim/fun/multiply_log.hpp>
1415
#include <stan/math/prim/fun/scalar_seq_view.hpp>
1516
#include <stan/math/prim/fun/size.hpp>
@@ -153,8 +154,8 @@ return_type_t<T_x, T_alpha, T_beta, T_precision> neg_binomial_2_log_glm_lpmf(
153154
T_precision_val log_phi = log(phi_arr);
154155
Array<T_partials_return, Dynamic, 1> logsumexp_theta_logphi
155156
= (theta > log_phi)
156-
.select(theta + log1p(exp(log_phi - theta)),
157-
log_phi + log1p(exp(theta - log_phi)));
157+
.select(theta + log1p_exp(log_phi - theta),
158+
log_phi + log1p_exp(theta - log_phi));
158159

159160
T_sum_val y_plus_phi = y_arr + phi_arr;
160161

0 commit comments

Comments
 (0)