33
44#include < stan/math/prim/meta.hpp>
55#include < stan/math/prim/err.hpp>
6- #include < stan/math/prim/fun/constants .hpp>
6+ #include < stan/math/prim/fun/any .hpp>
77#include < stan/math/prim/fun/inv.hpp>
88#include < stan/math/prim/fun/log.hpp>
9- #include < stan/math/prim/fun/max_size.hpp>
10- #include < stan/math/prim/fun/scalar_seq_view.hpp>
119#include < stan/math/prim/fun/select.hpp>
12- #include < stan/math/prim/fun/size.hpp>
1310#include < stan/math/prim/fun/size_zero.hpp>
14- #include < stan/math/prim/fun/value_of.hpp>
1511#include < stan/math/prim/functor/partials_propagator.hpp>
16- #include < cmath>
1712
1813namespace stan {
1914namespace math {
@@ -35,14 +30,13 @@ template <typename T_n, typename T_prob,
3530 T_n, T_prob>* = nullptr >
3631return_type_t <T_prob> bernoulli_lccdf (const T_n& n, const T_prob& theta) {
3732 using T_theta_ref = ref_type_t <T_prob>;
38- using std::log;
3933 static const char * function = " bernoulli_lccdf" ;
4034 check_consistent_sizes (function, " Random variable" , n,
4135 " Probability parameter" , theta);
4236 T_theta_ref theta_ref = theta;
4337 const auto & n_arr = as_array_or_scalar (n);
44- check_bounded (function, " Probability parameter " , value_of (theta_ref), 0.0 ,
45- 1.0 );
38+ const auto & theta_arr = as_value_column_array_or_scalar (theta_ref);
39+ check_bounded (function, " Probability parameter " , theta_arr, 0.0 , 1.0 );
4640
4741 if (size_zero (n, theta)) {
4842 return 0.0 ;
@@ -52,19 +46,19 @@ return_type_t<T_prob> bernoulli_lccdf(const T_n& n, const T_prob& theta) {
5246
5347 // Explicit return for extreme values
5448 // The gradients are technically ill-defined, but treated as zero
55- if (sum (n_arr < 0 )) {
49+ if (any (n_arr < 0 )) {
5650 return ops_partials.build (0.0 );
57- } else if (sum (n_arr >= 1 )) {
51+ } else if (any (n_arr >= 1 )) {
5852 return ops_partials.build (NEGATIVE_INFTY);
5953 }
6054
61- const auto & theta_arr = as_value_column_array_or_scalar (theta_ref );
55+ const auto & theta_broadcast = select ( true , theta_arr, n_arr );
6256
6357 if (!is_constant_all<T_prob>::value) {
64- partials<0 >(ops_partials) = select ( true , inv (theta_arr), n_arr );
58+ partials<0 >(ops_partials) = inv (theta_broadcast );
6559 }
6660
67- return ops_partials.build (sum (select ( true , log (theta_arr), n_arr )));
61+ return ops_partials.build (sum (log (theta_broadcast )));
6862}
6963
7064} // namespace math
0 commit comments