|
7 | 7 | #include <stan/math/prim/err.hpp> |
8 | 8 | #include <stan/math/prim/functor/partials_propagator.hpp> |
9 | 9 | #include <stan/math/prim/fun/binomial_coefficient_log.hpp> |
10 | | -#include <stan/math/prim/fun/inv_logit.hpp> |
| 10 | +#include <stan/math/prim/fun/log_inv_logit.hpp> |
| 11 | +#include <stan/math/prim/fun/log1m_inv_logit.hpp> |
11 | 12 |
|
12 | 13 | namespace stan { |
13 | 14 | namespace math { |
@@ -60,18 +61,15 @@ return_type_t<T_prob_cl> binomial_logit_lpmf(const T_n_cl& n, const T_N_cl N, |
60 | 61 | = check_cl(function, "Probability parameter", alpha_val, "finite"); |
61 | 62 | auto alpha_finite = isfinite(alpha_val); |
62 | 63 |
|
63 | | - auto inv_logit_alpha = inv_logit(alpha_val); |
64 | | - auto inv_logit_neg_alpha = inv_logit(-alpha_val); |
65 | | - auto log_inv_logit_alpha = log(inv_logit_alpha); |
66 | | - auto log_inv_logit_neg_alpha = log(inv_logit_neg_alpha); |
| 64 | + auto log_inv_logit_alpha = log_inv_logit(alpha); |
| 65 | + auto log1m_inv_logit_alpha = log1m_inv_logit(alpha); |
67 | 66 | auto n_diff = N - n; |
68 | 67 | auto logp_expr1 = elt_multiply(n, log_inv_logit_alpha) |
69 | | - + elt_multiply(n_diff, log_inv_logit_neg_alpha); |
| 68 | + + elt_multiply(n_diff, log1m_inv_logit_alpha); |
70 | 69 | auto logp_expr |
71 | 70 | = static_select<include_summand<propto, T_n_cl, T_N_cl>::value>( |
72 | 71 | logp_expr1 + binomial_coefficient_log(N, n), logp_expr1); |
73 | | - auto alpha_deriv = elt_multiply(n, inv_logit_neg_alpha) |
74 | | - - elt_multiply(n_diff, inv_logit_alpha); |
| 72 | + auto alpha_deriv = n - elt_multiply(N, exp(log_inv_logit_alpha)); |
75 | 73 |
|
76 | 74 | matrix_cl<double> logp_cl; |
77 | 75 | matrix_cl<double> alpha_deriv_cl; |
|
0 commit comments