Skip to content

Commit c8f2afd

Browse files
committed
log_inv_logit
1 parent 9789fb8 commit c8f2afd

2 files changed

Lines changed: 12 additions & 15 deletions

File tree

stan/math/opencl/prim/binomial_logit_lpmf.hpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
#include <stan/math/prim/err.hpp>
88
#include <stan/math/prim/functor/partials_propagator.hpp>
99
#include <stan/math/prim/fun/binomial_coefficient_log.hpp>
10-
#include <stan/math/prim/fun/inv_logit.hpp>
10+
#include <stan/math/prim/fun/log_inv_logit.hpp>
11+
#include <stan/math/prim/fun/log1m_inv_logit.hpp>
1112

1213
namespace stan {
1314
namespace math {
@@ -60,18 +61,15 @@ return_type_t<T_prob_cl> binomial_logit_lpmf(const T_n_cl& n, const T_N_cl N,
6061
= check_cl(function, "Probability parameter", alpha_val, "finite");
6162
auto alpha_finite = isfinite(alpha_val);
6263

63-
auto inv_logit_alpha = inv_logit(alpha_val);
64-
auto inv_logit_neg_alpha = inv_logit(-alpha_val);
65-
auto log_inv_logit_alpha = log(inv_logit_alpha);
66-
auto log_inv_logit_neg_alpha = log(inv_logit_neg_alpha);
64+
auto log_inv_logit_alpha = log_inv_logit(alpha);
65+
auto log1m_inv_logit_alpha = log1m_inv_logit(alpha);
6766
auto n_diff = N - n;
6867
auto logp_expr1 = elt_multiply(n, log_inv_logit_alpha)
69-
+ elt_multiply(n_diff, log_inv_logit_neg_alpha);
68+
+ elt_multiply(n_diff, log1m_inv_logit_alpha);
7069
auto logp_expr
7170
= static_select<include_summand<propto, T_n_cl, T_N_cl>::value>(
7271
logp_expr1 + binomial_coefficient_log(N, n), logp_expr1);
73-
auto alpha_deriv = elt_multiply(n, inv_logit_neg_alpha)
74-
- elt_multiply(n_diff, inv_logit_alpha);
72+
auto alpha_deriv = n - elt_multiply(N, exp(log_inv_logit_alpha));
7573

7674
matrix_cl<double> logp_cl;
7775
matrix_cl<double> alpha_deriv_cl;

stan/math/prim/fun/prob_constrain.hpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
#define STAN_MATH_PRIM_FUN_PROB_CONSTRAIN_HPP
33

44
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/fun/inv_logit.hpp>
6-
#include <stan/math/prim/fun/log.hpp>
7-
#include <stan/math/prim/fun/log1m.hpp>
5+
#include <stan/math/prim/fun/log_inv_logit.hpp>
6+
#include <stan/math/prim/fun/exp.hpp>
7+
#include <stan/math/prim/fun/log1m_inv_logit.hpp>
88
#include <cmath>
99

1010
namespace stan {
@@ -49,10 +49,9 @@ inline T prob_constrain(const T& x) {
4949
*/
5050
template <typename T>
5151
inline T prob_constrain(const T& x, T& lp) {
52-
using std::log;
53-
T inv_logit_x = inv_logit(x);
54-
lp += log(inv_logit_x) + log1m(inv_logit_x);
55-
return inv_logit_x;
52+
T log_inv_logit_x = log_inv_logit(x);
53+
lp += log_inv_logit_x + log1m_inv_logit(x);
54+
return exp(log_inv_logit_x);
5655
}
5756

5857
/**

0 commit comments

Comments
 (0)