Skip to content

Commit 9f2689e

Browse files
authored
Merge pull request #2784 from andrjohns/issue-2783-bernoulli-cdf-stable
Improve Numerical Stability of Bernoulli CDF functions
2 parents c001c71 + 47c4f9a commit 9f2689e

6 files changed

Lines changed: 110 additions & 98 deletions

File tree

stan/math/prim/prob/bernoulli_cdf.hpp

Lines changed: 12 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,10 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/err.hpp>
6-
#include <stan/math/prim/fun/constants.hpp>
7-
#include <stan/math/prim/fun/max_size.hpp>
8-
#include <stan/math/prim/fun/scalar_seq_view.hpp>
6+
#include <stan/math/prim/fun/any.hpp>
7+
#include <stan/math/prim/fun/select.hpp>
98
#include <stan/math/prim/fun/size.hpp>
109
#include <stan/math/prim/fun/size_zero.hpp>
11-
#include <stan/math/prim/fun/value_of.hpp>
1210
#include <stan/math/prim/functor/partials_propagator.hpp>
1311

1412
namespace stan {
@@ -36,50 +34,30 @@ return_type_t<T_prob> bernoulli_cdf(const T_n& n, const T_prob& theta) {
3634
check_consistent_sizes(function, "Random variable", n,
3735
"Probability parameter", theta);
3836
T_theta_ref theta_ref = theta;
39-
check_bounded(function, "Probability parameter", value_of(theta_ref), 0.0,
40-
1.0);
37+
const auto& n_arr = as_array_or_scalar(n);
38+
const auto& theta_arr = as_value_column_array_or_scalar(theta_ref);
39+
check_bounded(function, "Probability parameter", theta_arr, 0.0, 1.0);
4140

4241
if (size_zero(n, theta)) {
4342
return 1.0;
4443
}
4544

46-
T_partials_return P(1.0);
4745
auto ops_partials = make_partials_propagator(theta_ref);
4846

49-
scalar_seq_view<T_n> n_vec(n);
50-
scalar_seq_view<T_theta_ref> theta_vec(theta_ref);
51-
size_t max_size_seq_view = max_size(n, theta);
52-
5347
// Explicit return for extreme values
5448
// The gradients are technically ill-defined, but treated as zero
55-
for (size_t i = 0; i < stan::math::size(n); i++) {
56-
if (n_vec.val(i) < 0) {
57-
return ops_partials.build(0.0);
58-
}
49+
if (any(n_arr < 0)) {
50+
return ops_partials.build(0.0);
5951
}
52+
const auto& log1m_theta = select(theta_arr == 1, 0.0, log1m(theta_arr));
53+
const auto& P1 = select(n_arr == 0, log1m_theta, 0.0);
6054

61-
for (size_t i = 0; i < max_size_seq_view; i++) {
62-
// Explicit results for extreme values
63-
// The gradients are technically ill-defined, but treated as zero
64-
if (n_vec.val(i) >= 1) {
65-
continue;
66-
}
67-
68-
const T_partials_return Pi = 1 - theta_vec.val(i);
69-
70-
P *= Pi;
71-
72-
if (!is_constant_all<T_prob>::value) {
73-
partials<0>(ops_partials)[i] += -1 / Pi;
74-
}
75-
}
55+
T_partials_return P = sum(P1);
7656

7757
if (!is_constant_all<T_prob>::value) {
78-
for (size_t i = 0; i < stan::math::size(theta); ++i) {
79-
partials<0>(ops_partials)[i] *= P;
80-
}
58+
partials<0>(ops_partials) = select(n_arr == 0, -exp(P - P1), 0.0);
8159
}
82-
return ops_partials.build(P);
60+
return ops_partials.build(exp(P));
8361
}
8462

8563
} // namespace math

stan/math/prim/prob/bernoulli_lccdf.hpp

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,13 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/any.hpp>
67
#include <stan/math/prim/fun/constants.hpp>
78
#include <stan/math/prim/fun/inv.hpp>
89
#include <stan/math/prim/fun/log.hpp>
9-
#include <stan/math/prim/fun/max_size.hpp>
10-
#include <stan/math/prim/fun/scalar_seq_view.hpp>
11-
#include <stan/math/prim/fun/size.hpp>
10+
#include <stan/math/prim/fun/select.hpp>
1211
#include <stan/math/prim/fun/size_zero.hpp>
13-
#include <stan/math/prim/fun/value_of.hpp>
1412
#include <stan/math/prim/functor/partials_propagator.hpp>
15-
#include <cmath>
1613

1714
namespace stan {
1815
namespace math {
@@ -33,50 +30,38 @@ template <typename T_n, typename T_prob,
3330
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
3431
T_n, T_prob>* = nullptr>
3532
return_type_t<T_prob> bernoulli_lccdf(const T_n& n, const T_prob& theta) {
36-
using T_partials_return = partials_return_t<T_n, T_prob>;
3733
using T_theta_ref = ref_type_t<T_prob>;
38-
using std::log;
3934
static const char* function = "bernoulli_lccdf";
4035
check_consistent_sizes(function, "Random variable", n,
4136
"Probability parameter", theta);
4237
T_theta_ref theta_ref = theta;
43-
check_bounded(function, "Probability parameter", value_of(theta_ref), 0.0,
44-
1.0);
38+
const auto& n_arr = as_array_or_scalar(n);
39+
const auto& theta_arr = as_value_column_array_or_scalar(theta_ref);
40+
check_bounded(function, "Probability parameter", theta_arr, 0.0, 1.0);
4541

4642
if (size_zero(n, theta)) {
4743
return 0.0;
4844
}
4945

50-
T_partials_return P(0.0);
5146
auto ops_partials = make_partials_propagator(theta_ref);
5247

53-
scalar_seq_view<T_n> n_vec(n);
54-
scalar_seq_view<T_theta_ref> theta_vec(theta_ref);
55-
size_t max_size_seq_view = max_size(n, theta);
56-
5748
// Explicit return for extreme values
5849
// The gradients are technically ill-defined, but treated as zero
59-
for (size_t i = 0; i < stan::math::size(n); i++) {
60-
const double n_dbl = n_vec.val(i);
61-
if (n_dbl < 0) {
62-
return ops_partials.build(0.0);
63-
}
64-
if (n_dbl >= 1) {
65-
return ops_partials.build(NEGATIVE_INFTY);
66-
}
50+
if (any(n_arr < 0)) {
51+
return ops_partials.build(0.0);
52+
} else if (any(n_arr >= 1)) {
53+
return ops_partials.build(NEGATIVE_INFTY);
6754
}
6855

69-
for (size_t i = 0; i < max_size_seq_view; i++) {
70-
const T_partials_return Pi = theta_vec.val(i);
71-
72-
P += log(Pi);
56+
size_t theta_size = math::size(theta_arr);
57+
size_t n_size = math::size(n_arr);
58+
double broadcast_n = theta_size == n_size ? 1 : n_size;
7359

74-
if (!is_constant_all<T_prob>::value) {
75-
partials<0>(ops_partials)[i] += inv(Pi);
76-
}
60+
if (!is_constant_all<T_prob>::value) {
61+
partials<0>(ops_partials) = inv(theta_arr) * broadcast_n;
7762
}
7863

79-
return ops_partials.build(P);
64+
return ops_partials.build(sum(log(theta_arr)) * broadcast_n);
8065
}
8166

8267
} // namespace math

stan/math/prim/prob/bernoulli_lcdf.hpp

Lines changed: 11 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,11 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/any.hpp>
67
#include <stan/math/prim/fun/constants.hpp>
7-
#include <stan/math/prim/fun/inv.hpp>
8-
#include <stan/math/prim/fun/log.hpp>
9-
#include <stan/math/prim/fun/max_size.hpp>
10-
#include <stan/math/prim/fun/scalar_seq_view.hpp>
11-
#include <stan/math/prim/fun/size.hpp>
8+
#include <stan/math/prim/fun/select.hpp>
129
#include <stan/math/prim/fun/size_zero.hpp>
13-
#include <stan/math/prim/fun/value_of.hpp>
1410
#include <stan/math/prim/functor/partials_propagator.hpp>
15-
#include <cmath>
1611

1712
namespace stan {
1813
namespace math {
@@ -33,52 +28,34 @@ template <typename T_n, typename T_prob,
3328
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
3429
T_n, T_prob>* = nullptr>
3530
return_type_t<T_prob> bernoulli_lcdf(const T_n& n, const T_prob& theta) {
36-
using T_partials_return = partials_return_t<T_n, T_prob>;
3731
using T_theta_ref = ref_type_t<T_prob>;
38-
using std::log;
3932
static const char* function = "bernoulli_lcdf";
4033
check_consistent_sizes(function, "Random variable", n,
4134
"Probability parameter", theta);
4235
T_theta_ref theta_ref = theta;
43-
check_bounded(function, "Probability parameter", value_of(theta_ref), 0.0,
44-
1.0);
36+
const auto& n_arr = as_array_or_scalar(n);
37+
const auto& theta_arr = as_value_column_array_or_scalar(theta_ref);
38+
check_bounded(function, "Probability parameter", theta_arr, 0.0, 1.0);
4539

4640
if (size_zero(n, theta)) {
4741
return 0.0;
4842
}
4943

50-
T_partials_return P(0.0);
5144
auto ops_partials = make_partials_propagator(theta_ref);
5245

53-
scalar_seq_view<T_n> n_vec(n);
54-
scalar_seq_view<T_theta_ref> theta_vec(theta_ref);
55-
size_t max_size_seq_view = max_size(n, theta);
56-
5746
// Explicit return for extreme values
5847
// The gradients are technically ill-defined, but treated as zero
59-
for (size_t i = 0; i < stan::math::size(n); i++) {
60-
if (n_vec.val(i) < 0) {
61-
return ops_partials.build(NEGATIVE_INFTY);
62-
}
48+
if (any(n_arr < 0)) {
49+
return ops_partials.build(NEGATIVE_INFTY);
6350
}
6451

65-
for (size_t i = 0; i < max_size_seq_view; i++) {
66-
// Explicit results for extreme values
67-
// The gradients are technically ill-defined, but treated as zero
68-
if (n_vec.val(i) >= 1) {
69-
continue;
70-
}
71-
72-
const T_partials_return Pi = 1 - theta_vec.val(i);
73-
74-
P += log(Pi);
52+
const auto& log1m_theta = select(theta_arr == 1, 0.0, log1m(theta_arr));
7553

76-
if (!is_constant_all<T_prob>::value) {
77-
partials<0>(ops_partials)[i] -= inv(Pi);
78-
}
54+
if (!is_constant_all<T_prob>::value) {
55+
partials<0>(ops_partials) = select(n_arr == 0, -exp(-log1m_theta), 0.0);
7956
}
8057

81-
return ops_partials.build(P);
58+
return ops_partials.build(sum(select(n_arr == 0, log1m_theta, 0.0)));
8259
}
8360

8461
} // namespace math
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#include <test/unit/math/test_ad.hpp>
2+
#include <limits>
3+
4+
TEST(mathMixScalFun, bernoulliCDF) {
5+
// bind integer arg because can't autodiff through
6+
auto f = [](const auto& x1) {
7+
return [=](const auto& x2) { return stan::math::bernoulli_cdf(x1, x2); };
8+
};
9+
stan::test::expect_ad(f(0), 0.1);
10+
stan::test::expect_ad(f(0), std::numeric_limits<double>::quiet_NaN());
11+
stan::test::expect_ad(f(1), 0.5);
12+
stan::test::expect_ad(f(1), std::numeric_limits<double>::quiet_NaN());
13+
stan::test::expect_ad(f(1), 0.2);
14+
15+
std::vector<int> std_in1{0, 1};
16+
Eigen::VectorXd in2(2);
17+
in2 << 0.5, 0.9;
18+
19+
stan::test::expect_ad(f(std_in1), 0.2);
20+
stan::test::expect_ad(f(std_in1), std::numeric_limits<double>::quiet_NaN());
21+
stan::test::expect_ad(f(1), in2);
22+
stan::test::expect_ad(f(std_in1), in2);
23+
stan::test::expect_ad(f(std_in1), std::numeric_limits<double>::quiet_NaN());
24+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#include <test/unit/math/test_ad.hpp>
2+
#include <limits>
3+
4+
TEST(mathMixScalFun, bernoulliLCCDF) {
5+
// bind integer arg because can't autodiff through
6+
auto f = [](const auto& x1) {
7+
return [=](const auto& x2) { return stan::math::bernoulli_lccdf(x1, x2); };
8+
};
9+
stan::test::expect_ad(f(0), 0.1);
10+
stan::test::expect_ad(f(0), std::numeric_limits<double>::quiet_NaN());
11+
stan::test::expect_ad(f(1), 0.5);
12+
stan::test::expect_ad(f(1), std::numeric_limits<double>::quiet_NaN());
13+
stan::test::expect_ad(f(1), 0.2);
14+
15+
std::vector<int> std_in1{0, 1};
16+
Eigen::VectorXd in2(2);
17+
in2 << 0.5, 0.9;
18+
19+
stan::test::expect_ad(f(std_in1), 0.2);
20+
stan::test::expect_ad(f(std_in1), std::numeric_limits<double>::quiet_NaN());
21+
stan::test::expect_ad(f(1), in2);
22+
stan::test::expect_ad(f(std_in1), in2);
23+
stan::test::expect_ad(f(std_in1), std::numeric_limits<double>::quiet_NaN());
24+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#include <test/unit/math/test_ad.hpp>
2+
#include <limits>
3+
4+
TEST(mathMixScalFun, bernoulliLCDF) {
5+
// bind integer arg because can't autodiff through
6+
auto f = [](const auto& x1) {
7+
return [=](const auto& x2) { return stan::math::bernoulli_lcdf(x1, x2); };
8+
};
9+
stan::test::expect_ad(f(0), 0.1);
10+
stan::test::expect_ad(f(0), std::numeric_limits<double>::quiet_NaN());
11+
stan::test::expect_ad(f(1), 0.5);
12+
stan::test::expect_ad(f(1), std::numeric_limits<double>::quiet_NaN());
13+
stan::test::expect_ad(f(1), 0.2);
14+
15+
std::vector<int> std_in1{0, 1};
16+
Eigen::VectorXd in2(2);
17+
in2 << 0.5, 0.9;
18+
19+
stan::test::expect_ad(f(std_in1), 0.2);
20+
stan::test::expect_ad(f(std_in1), std::numeric_limits<double>::quiet_NaN());
21+
stan::test::expect_ad(f(1), in2);
22+
stan::test::expect_ad(f(std_in1), in2);
23+
stan::test::expect_ad(f(std_in1), std::numeric_limits<double>::quiet_NaN());
24+
}

0 commit comments

Comments
 (0)