Skip to content

Commit eb3b5d7

Browse files
authored
Merge pull request #2945 from stan-dev/binomial-logit-numerics
Improve numerical stability of binomial_logit_lpmf
2 parents 9f2689e + 2a81187 commit eb3b5d7

2 files changed

Lines changed: 32 additions & 23 deletions

File tree

stan/math/prim/prob/binomial_logit_lpmf.hpp

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,11 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/err.hpp>
6-
#include <stan/math/prim/fun/as_column_vector_or_scalar.hpp>
7-
#include <stan/math/prim/fun/as_array_or_scalar.hpp>
86
#include <stan/math/prim/fun/as_value_column_array_or_scalar.hpp>
97
#include <stan/math/prim/fun/binomial_coefficient_log.hpp>
10-
#include <stan/math/prim/fun/inc_beta.hpp>
11-
#include <stan/math/prim/fun/inv_logit.hpp>
12-
#include <stan/math/prim/fun/lbeta.hpp>
13-
#include <stan/math/prim/fun/log.hpp>
8+
#include <stan/math/prim/fun/log_inv_logit.hpp>
9+
#include <stan/math/prim/fun/log1m_inv_logit.hpp>
10+
#include <stan/math/prim/fun/exp.hpp>
1411
#include <stan/math/prim/fun/max_size.hpp>
1512
#include <stan/math/prim/fun/size.hpp>
1613
#include <stan/math/prim/fun/size_zero.hpp>
@@ -66,33 +63,22 @@ return_type_t<T_prob> binomial_logit_lpmf(const T_n& n, const T_N& N,
6663
if (!include_summand<propto, T_prob>::value) {
6764
return 0.0;
6865
}
69-
const auto& inv_logit_alpha
70-
= to_ref_if<!is_constant_all<T_prob>::value>(inv_logit(alpha_val));
71-
const auto& inv_logit_neg_alpha
72-
= to_ref_if<!is_constant_all<T_prob>::value>(inv_logit(-alpha_val));
66+
const auto& log_inv_logit_alpha
67+
= to_ref_if<!is_constant_all<T_prob>::value>(log_inv_logit(alpha_val));
68+
const auto& log1m_inv_logit_alpha
69+
= to_ref_if<!is_constant_all<T_prob>::value>(log1m_inv_logit(alpha_val));
7370

7471
size_t maximum_size = max_size(n, N, alpha);
75-
const auto& log_inv_logit_alpha = log(inv_logit_alpha);
76-
const auto& log_inv_logit_neg_alpha = log(inv_logit_neg_alpha);
7772
T_partials_return logp = sum(n_val * log_inv_logit_alpha
78-
+ (N_val - n_val) * log_inv_logit_neg_alpha);
73+
+ (N_val - n_val) * log1m_inv_logit_alpha);
7974
if (include_summand<propto, T_n, T_N>::value) {
8075
logp += sum(binomial_coefficient_log(N_val, n_val)) * maximum_size
8176
/ max_size(n, N);
8277
}
8378

8479
auto ops_partials = make_partials_propagator(alpha_ref);
8580
if (!is_constant_all<T_prob>::value) {
86-
if (is_vector<T_prob>::value) {
87-
edge<0>(ops_partials).partials_
88-
= n_val * inv_logit_neg_alpha - (N_val - n_val) * inv_logit_alpha;
89-
} else {
90-
T_partials_return sum_n = sum(n_val) * maximum_size / math::size(n);
91-
partials<0>(ops_partials)[0] = forward_as<T_partials_return>(
92-
sum_n * inv_logit_neg_alpha
93-
- (sum(N_val) * maximum_size / math::size(N) - sum_n)
94-
* inv_logit_alpha);
95-
}
81+
edge<0>(ops_partials).partials_ = n_val - N_val * exp(log_inv_logit_alpha);
9682
}
9783

9884
return ops_partials.build(logp);
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#include <stan/math/mix.hpp>
2+
#include <test/unit/math/test_ad.hpp>
3+
4+
TEST(mathMixScalFun, binomial_logit_lpmf) {
5+
auto f = [](const auto n, const auto N) {
6+
return [=](const auto& alpha) {
7+
return stan::math::binomial_logit_lpmf(n, N, alpha);
8+
};
9+
};
10+
11+
Eigen::VectorXd alpha = Eigen::VectorXd::Random(3);
12+
std::vector<int> n_arr{1, 4, 5};
13+
std::vector<int> N_arr{10, 45, 25};
14+
15+
stan::test::expect_ad(f(5, 25), 2.11);
16+
stan::test::expect_ad(f(5, 25), alpha);
17+
stan::test::expect_ad(f(n_arr, 25), alpha);
18+
stan::test::expect_ad(f(n_arr, N_arr), alpha);
19+
stan::test::expect_ad(f(n_arr, 10), 2.11);
20+
stan::test::expect_ad(f(n_arr, N_arr), 2.11);
21+
stan::test::expect_ad(f(5, N_arr), 2.11);
22+
stan::test::expect_ad(f(5, N_arr), alpha);
23+
}

0 commit comments

Comments
 (0)