@@ -9,13 +9,13 @@ namespace stan {
99namespace math {
1010
1111/* *
12- * Return the second argument if the first argument is true
13- * and otherwise return the third argument.
12+ * If first argument is true return the second argument,
13+ * else return the third argument.
1414 *
1515 * `select(c, y1, y0) = c ? y1 : y0`.
1616 *
17- * @tparam T_true type of the true argument
18- * @tparam T_false type of the false argument
17+ * @tparam T_true A stan `Scalar` type
18+ * @tparam T_false A stan `Scalar` type
1919 * @param c Boolean condition value.
2020 * @param y_true Value to return if condition is true.
2121 * @param y_false Value to return if condition is false.
@@ -29,17 +29,17 @@ inline ReturnT select(const bool c, const T_true y_true,
2929}
3030
3131/* *
32- * Return the second argument if the first argument is true
33- * and otherwise return the third argument. Eigen expressions are
32+ * If first argument is true return the second argument,
33+ * else return the third argument. Eigen expressions are
3434 * evaluated so that the return type is the same for both branches.
3535 *
3636 * Both containers must have the same plain type. The scalar type
3737 * of the return is determined by the return_type_t<> type trait.
3838 *
3939 * Overload for use with two containers.
4040 *
41- * @tparam T_true type of the true argument
42- * @tparam T_false type of the false argument
41+ * @tparam T_true A container of stan `Scalar` types
42+ * @tparam T_false A container of stan `Scalar` types
4343 * @param c Boolean condition value.
4444 * @param y_true Value to return if condition is true.
4545 * @param y_false Value to return if condition is false.
@@ -59,8 +59,8 @@ inline T_true_plain select(const bool c, T_true&& y_true, T_false&& y_false) {
5959}
6060
6161/* *
62- * Return the second argument if the first argument is true
63- * and otherwise return the third argument.
62+ * If first argument is true return the second argument,
63+ * else return the third argument.
6464 *
6565 * Overload for use when the 'true' return is a container and the 'false'
6666 * return is a scalar
@@ -69,8 +69,8 @@ inline T_true_plain select(const bool c, T_true&& y_true, T_false&& y_false) {
6969 * plain type as the provided argument. Consequently, any Eigen expressions are
7070 * evaluated.
7171 *
72- * @tparam T_true type of the true argument
73- * @tparam T_false type of the false argument
72+ * @tparam T_true A container of stan `Scalar` types
73+ * @tparam T_false A stan `Scalar` type
7474 * @param c Boolean condition value.
7575 * @param y_true Value to return if condition is true.
7676 * @param y_false Value to return if condition is false.
@@ -94,8 +94,8 @@ inline ReturnT select(const bool c, const T_true& y_true,
9494}
9595
9696/* *
97- * Return the second argument if the first argument is true
98- * and otherwise return the third argument.
97+ * If first argument is true return the second argument,
98+ * else return the third argument.
9999 *
100100 * Overload for use when the 'true' return is a scalar and the 'false'
101101 * return is a container
@@ -104,8 +104,8 @@ inline ReturnT select(const bool c, const T_true& y_true,
104104 * plain type as the provided argument. Consequently, any Eigen expressions are
105105 * evaluated.
106106 *
107- * @tparam T_true type of the true argument
108- * @tparam T_false type of the false argument
107+ * @tparam T_true A stan `Scalar` type
108+ * @tparam T_false A container of stan `Scalar` types
109109 * @param c Boolean condition value.
110110 * @param y_true Value to return if condition is true.
111111 * @param y_false Value to return if condition is false.
@@ -129,16 +129,16 @@ inline ReturnT select(const bool c, const T_true y_true,
129129}
130130
131131/* *
132- * Return the second argument if the first argument is true
133- * and otherwise return the third argument. Overload for use with an Eigen
132+ * If first argument is true return the second argument,
133+ * else return the third argument. Overload for use with an Eigen
134134 * object of booleans, and two scalars.
135135 *
136136 * The chosen scalar is returned as an Eigen object of the same dimension
137137 * as the input Eigen argument
138138 *
139139 * @tparam T_bool type of Eigen boolean object
140- * @tparam T_true type of the true argument
141- * @tparam T_false type of the false argument
140+ * @tparam T_true A stan `Scalar` type
141+ * @tparam T_false A stan `Scalar` type
142142 * @param c Eigen object of boolean condition values.
143143 * @param y_true Value to return if condition is true.
144144 * @param y_false Value to return if condition is false.
@@ -147,20 +147,22 @@ template <typename T_bool, typename T_true, typename T_false,
147147 require_eigen_array_vt<std::is_integral, T_bool>* = nullptr ,
148148 require_all_stan_scalar_t <T_true, T_false>* = nullptr >
149149inline auto select (const T_bool c, const T_true y_true, const T_false y_false) {
150+ using ret_t = return_type_t <T_true, T_false>;
150151 return c
151152 .unaryExpr (
152- [y_true, y_false](bool cond) { return cond ? y_true : y_false; })
153+ [y_true, y_false](bool cond) {
154+ return cond ? ret_t (y_true) : ret_t (y_false); })
153155 .eval ();
154156}
155157
156158/* *
157- * Return the second argument if the first argument is true
158- * and otherwise return the third argument. Overload for use with an Eigen
159+ * If first argument is true return the second argument,
160+ * else return the third argument. Overload for use with an Eigen
159161 * array of booleans, one Eigen array and a scalar as input.
160162 *
161163 * @tparam T_bool type of Eigen boolean object
162- * @tparam T_true type of the true argument
163- * @tparam T_false type of the false argument
164+ * @tparam T_true A stan `Scalar` type or Eigen Array type
165+ * @tparam T_false A stan `Scalar` type or Eigen Array type
164166 * @param c Eigen object of boolean condition values.
165167 * @param y_true Value to return if condition is true.
166168 * @param y_false Value to return if condition is false.
@@ -171,7 +173,8 @@ template <typename T_bool, typename T_true, typename T_false,
171173inline auto select (const T_bool c, const T_true y_true, const T_false y_false) {
172174 check_consistent_sizes (" select" , " boolean" , c, " left hand side" , y_true,
173175 " right hand side" , y_false);
174- return c.select (y_true, y_false).eval ();
176+ using ret_t = return_type_t <T_true, T_false>;
177+ return c.select (y_true, y_false).template cast <ret_t >().eval ();
175178}
176179
177180} // namespace math
0 commit comments