|
3 | 3 |
|
4 | 4 | #include <stan/math/prim/meta.hpp> |
5 | 5 | #include <stan/math/prim/err.hpp> |
6 | | -#include <stan/math/prim/fun/as_column_vector_or_scalar.hpp> |
7 | | -#include <stan/math/prim/fun/as_array_or_scalar.hpp> |
8 | 6 | #include <stan/math/prim/fun/as_value_column_array_or_scalar.hpp> |
9 | 7 | #include <stan/math/prim/fun/binomial_coefficient_log.hpp> |
10 | | -#include <stan/math/prim/fun/inc_beta.hpp> |
11 | | -#include <stan/math/prim/fun/inv_logit.hpp> |
12 | | -#include <stan/math/prim/fun/lbeta.hpp> |
13 | | -#include <stan/math/prim/fun/log.hpp> |
| 8 | +#include <stan/math/prim/fun/log_inv_logit.hpp> |
| 9 | +#include <stan/math/prim/fun/log1m_inv_logit.hpp> |
| 10 | +#include <stan/math/prim/fun/exp.hpp> |
14 | 11 | #include <stan/math/prim/fun/max_size.hpp> |
15 | 12 | #include <stan/math/prim/fun/size.hpp> |
16 | 13 | #include <stan/math/prim/fun/size_zero.hpp> |
@@ -66,33 +63,22 @@ return_type_t<T_prob> binomial_logit_lpmf(const T_n& n, const T_N& N, |
66 | 63 | if (!include_summand<propto, T_prob>::value) { |
67 | 64 | return 0.0; |
68 | 65 | } |
69 | | - const auto& inv_logit_alpha |
70 | | - = to_ref_if<!is_constant_all<T_prob>::value>(inv_logit(alpha_val)); |
71 | | - const auto& inv_logit_neg_alpha |
72 | | - = to_ref_if<!is_constant_all<T_prob>::value>(inv_logit(-alpha_val)); |
| 66 | + const auto& log_inv_logit_alpha |
| 67 | + = to_ref_if<!is_constant_all<T_prob>::value>(log_inv_logit(alpha_val)); |
| 68 | + const auto& log1m_inv_logit_alpha |
| 69 | + = to_ref_if<!is_constant_all<T_prob>::value>(log1m_inv_logit(alpha_val)); |
73 | 70 |
|
74 | 71 | size_t maximum_size = max_size(n, N, alpha); |
75 | | - const auto& log_inv_logit_alpha = log(inv_logit_alpha); |
76 | | - const auto& log_inv_logit_neg_alpha = log(inv_logit_neg_alpha); |
77 | 72 | T_partials_return logp = sum(n_val * log_inv_logit_alpha |
78 | | - + (N_val - n_val) * log_inv_logit_neg_alpha); |
| 73 | + + (N_val - n_val) * log1m_inv_logit_alpha); |
79 | 74 | if (include_summand<propto, T_n, T_N>::value) { |
80 | 75 | logp += sum(binomial_coefficient_log(N_val, n_val)) * maximum_size |
81 | 76 | / max_size(n, N); |
82 | 77 | } |
83 | 78 |
|
84 | 79 | auto ops_partials = make_partials_propagator(alpha_ref); |
85 | 80 | if (!is_constant_all<T_prob>::value) { |
86 | | - if (is_vector<T_prob>::value) { |
87 | | - edge<0>(ops_partials).partials_ |
88 | | - = n_val * inv_logit_neg_alpha - (N_val - n_val) * inv_logit_alpha; |
89 | | - } else { |
90 | | - T_partials_return sum_n = sum(n_val) * maximum_size / math::size(n); |
91 | | - partials<0>(ops_partials)[0] = forward_as<T_partials_return>( |
92 | | - sum_n * inv_logit_neg_alpha |
93 | | - - (sum(N_val) * maximum_size / math::size(N) - sum_n) |
94 | | - * inv_logit_alpha); |
95 | | - } |
| 81 | + edge<0>(ops_partials).partials_ = n_val - N_val * exp(log_inv_logit_alpha); |
96 | 82 | } |
97 | 83 |
|
98 | 84 | return ops_partials.build(logp); |
|
0 commit comments