@@ -36,11 +36,17 @@ inline auto select(const bool c, const T_true y_true, const T_false y_false) {
3636 * @param y_false Value to return if condition is false.
3737 */
3838template <typename T_true, typename T_false,
39- require_all_eigen_t <T_true, T_false>* = nullptr >
40- inline auto select (const bool c, const T_true y_true, const T_false y_false) {
41- return y_true
42- .binaryExpr (y_false, [&](auto && x, auto && y) { return c ? x : y; })
43- .eval ();
39+ typename T_return = return_type_t <T_true, T_false>,
40+ typename T_true_plain = promote_scalar_t <T_return, plain_type_t <T_true>>,
41+ typename T_false_plain = promote_scalar_t <T_return, plain_type_t <T_false>>,
42+ require_all_eigen_t <T_true, T_false>* = nullptr ,
43+ require_all_same_t <T_true_plain, T_false_plain>* = nullptr >
44+ inline T_true_plain select (const bool c, const T_true y_true, const T_false y_false) {
45+ if (c) {
46+ return y_true;
47+ } else {
48+ return y_false;
49+ }
4450}
4551
4652/* *
@@ -64,9 +70,9 @@ inline ReturnT select(const bool c, const T_true& y_true,
6470 const T_false& y_false) {
6571 if (c) {
6672 return y_true;
73+ } else {
74+ return y_true.unaryExpr ([&](auto && y) { return y_false; });
6775 }
68-
69- return y_true.unaryExpr ([&](auto && y) { return y_false; });
7076}
7177
7278/* *
@@ -90,9 +96,9 @@ inline ReturnT select(const bool c, const T_true y_true,
9096 const T_false y_false) {
9197 if (c) {
9298 return y_false.unaryExpr ([&](auto && y) { return y_true; });
99+ } else {
100+ return y_false;
93101 }
94-
95- return y_false;
96102}
97103
98104/* *
@@ -129,7 +135,8 @@ inline auto select(const T_bool c, const T_true y_true, const T_false y_false) {
129135 */
130136template <typename T_bool, typename T_true, typename T_false,
131137 require_eigen_array_t <T_bool>* = nullptr ,
132- require_any_eigen_array_t <T_true, T_false>* = nullptr >
138+ require_any_eigen_array_t <T_true, T_false>* = nullptr ,
139+ require_any_stan_scalar_t <T_true, T_false>* = nullptr >
133140inline auto select (const T_bool c, const T_true y_true, const T_false y_false) {
134141 return c.select (y_true, y_false).eval ();
135142}
0 commit comments