Skip to content

Commit da1f7a8

Browse files
committed
review comments
1 parent a4b384c commit da1f7a8

2 files changed

Lines changed: 18 additions & 12 deletions

File tree

stan/math/prim/fun/select.hpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,17 @@ inline auto select(const bool c, const T_true y_true, const T_false y_false) {
3636
* @param y_false Value to return if condition is false.
3737
*/
3838
template <typename T_true, typename T_false,
39-
require_all_eigen_t<T_true, T_false>* = nullptr>
40-
inline auto select(const bool c, const T_true y_true, const T_false y_false) {
41-
return y_true
42-
.binaryExpr(y_false, [&](auto&& x, auto&& y) { return c ? x : y; })
43-
.eval();
39+
typename T_return = return_type_t<T_true, T_false>,
40+
typename T_true_plain = promote_scalar_t<T_return, plain_type_t<T_true>>,
41+
typename T_false_plain = promote_scalar_t<T_return, plain_type_t<T_false>>,
42+
require_all_eigen_t<T_true, T_false>* = nullptr,
43+
require_all_same_t<T_true_plain, T_false_plain>* = nullptr>
44+
inline T_true_plain select(const bool c, const T_true y_true, const T_false y_false) {
45+
if (c) {
46+
return y_true;
47+
} else {
48+
return y_false;
49+
}
4450
}
4551

4652
/**
@@ -64,9 +70,9 @@ inline ReturnT select(const bool c, const T_true& y_true,
6470
const T_false& y_false) {
6571
if (c) {
6672
return y_true;
73+
} else {
74+
return y_true.unaryExpr([&](auto&& y) { return y_false; });
6775
}
68-
69-
return y_true.unaryExpr([&](auto&& y) { return y_false; });
7076
}
7177

7278
/**
@@ -90,9 +96,9 @@ inline ReturnT select(const bool c, const T_true y_true,
9096
const T_false y_false) {
9197
if (c) {
9298
return y_false.unaryExpr([&](auto&& y) { return y_true; });
99+
} else {
100+
return y_false;
93101
}
94-
95-
return y_false;
96102
}
97103

98104
/**
@@ -129,7 +135,8 @@ inline auto select(const T_bool c, const T_true y_true, const T_false y_false) {
129135
*/
130136
template <typename T_bool, typename T_true, typename T_false,
131137
require_eigen_array_t<T_bool>* = nullptr,
132-
require_any_eigen_array_t<T_true, T_false>* = nullptr>
138+
require_any_eigen_array_t<T_true, T_false>* = nullptr,
139+
require_any_stan_scalar_t<T_true, T_false>* = nullptr>
133140
inline auto select(const T_bool c, const T_true y_true, const T_false y_false) {
134141
return c.select(y_true, y_false).eval();
135142
}

stan/math/prim/prob/bernoulli_lccdf.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ return_type_t<T_prob> bernoulli_lccdf(const T_n& n, const T_prob& theta) {
5454
// The gradients are technically ill-defined, but treated as zero
5555
if (sum(n_arr < 0)) {
5656
return ops_partials.build(0.0);
57-
}
58-
if (sum(n_arr >= 1)) {
57+
} else if (sum(n_arr >= 1)) {
5958
return ops_partials.build(NEGATIVE_INFTY);
6059
}
6160

0 commit comments

Comments
 (0)