33
44#include < stan/math/prim/meta.hpp>
55#include < stan/math/prim/err.hpp>
6+ #include < stan/math/prim/fun/any.hpp>
67#include < stan/math/prim/fun/constants.hpp>
7- #include < stan/math/prim/fun/inv.hpp>
8- #include < stan/math/prim/fun/log.hpp>
9- #include < stan/math/prim/fun/max_size.hpp>
10- #include < stan/math/prim/fun/scalar_seq_view.hpp>
11- #include < stan/math/prim/fun/size.hpp>
8+ #include < stan/math/prim/fun/select.hpp>
129#include < stan/math/prim/fun/size_zero.hpp>
13- #include < stan/math/prim/fun/value_of.hpp>
1410#include < stan/math/prim/functor/partials_propagator.hpp>
15- #include < cmath>
1611
1712namespace stan {
1813namespace math {
@@ -33,52 +28,34 @@ template <typename T_n, typename T_prob,
3328 require_all_not_nonscalar_prim_or_rev_kernel_expression_t <
3429 T_n, T_prob>* = nullptr >
3530return_type_t <T_prob> bernoulli_lcdf (const T_n& n, const T_prob& theta) {
36- using T_partials_return = partials_return_t <T_n, T_prob>;
3731 using T_theta_ref = ref_type_t <T_prob>;
38- using std::log;
3932 static const char * function = " bernoulli_lcdf" ;
4033 check_consistent_sizes (function, " Random variable" , n,
4134 " Probability parameter" , theta);
4235 T_theta_ref theta_ref = theta;
43- check_bounded (function, " Probability parameter" , value_of (theta_ref), 0.0 ,
44- 1.0 );
36+ const auto & n_arr = as_array_or_scalar (n);
37+ const auto & theta_arr = as_value_column_array_or_scalar (theta_ref);
38+ check_bounded (function, " Probability parameter" , theta_arr, 0.0 , 1.0 );
4539
4640 if (size_zero (n, theta)) {
4741 return 0.0 ;
4842 }
4943
50- T_partials_return P (0.0 );
5144 auto ops_partials = make_partials_propagator (theta_ref);
5245
53- scalar_seq_view<T_n> n_vec (n);
54- scalar_seq_view<T_theta_ref> theta_vec (theta_ref);
55- size_t max_size_seq_view = max_size (n, theta);
56-
5746 // Explicit return for extreme values
5847 // The gradients are technically ill-defined, but treated as zero
59- for (size_t i = 0 ; i < stan::math::size (n); i++) {
60- if (n_vec.val (i) < 0 ) {
61- return ops_partials.build (NEGATIVE_INFTY);
62- }
48+ if (any (n_arr < 0 )) {
49+ return ops_partials.build (NEGATIVE_INFTY);
6350 }
6451
65- for (size_t i = 0 ; i < max_size_seq_view; i++) {
66- // Explicit results for extreme values
67- // The gradients are technically ill-defined, but treated as zero
68- if (n_vec.val (i) >= 1 ) {
69- continue ;
70- }
71-
72- const T_partials_return Pi = 1 - theta_vec.val (i);
73-
74- P += log (Pi);
52+ const auto & log1m_theta = select (theta_arr == 1 , 0.0 , log1m (theta_arr));
7553
76- if (!is_constant_all<T_prob>::value) {
77- partials<0 >(ops_partials)[i] -= inv (Pi);
78- }
54+ if (!is_constant_all<T_prob>::value) {
55+ partials<0 >(ops_partials) = select (n_arr == 0 , -exp (-log1m_theta), 0.0 );
7956 }
8057
81- return ops_partials.build (P );
58+ return ops_partials.build (sum ( select (n_arr == 0 , log1m_theta, 0.0 )) );
8259}
8360
8461} // namespace math
0 commit comments