Skip to content

Commit c82075e

Browse files
committed
Review comments
1 parent 6bd8723 commit c82075e

5 files changed

Lines changed: 59 additions & 39 deletions

File tree

stan/math/prim/fun/all.hpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@ namespace math {
1111
/**
1212
* Return true if all values in the input are true.
1313
*
14-
* Overload for a single boolean input
14+
* Overload for a single integral input
1515
*
16+
* @tparam T The type of integral input.
1617
* @param x boolean input
1718
* @return The input unchanged
1819
*/
19-
template <typename T, require_integral_t<T>* = nullptr>
20-
constexpr inline T all(T x) {
20+
template <typename T, require_t<std::is_convertible<T, bool>>* = nullptr>
21+
constexpr inline bool all(T x) {
2122
return x;
2223
}
2324

@@ -37,20 +38,23 @@ inline bool all(const ContainerT& x) {
3738
return x.all();
3839
}
3940

41+
// Forward-declaration for correct resolution of all(std::vector<std::tuple>)
42+
template <typename... Types>
43+
inline bool all(const std::tuple<Types...>& x);
44+
4045
/**
4146
* Return true if all values in the input are true.
4247
*
4348
* Overload for a std::vector/nested inputs. The Eigen::Map/apply_vector_unary
4449
* approach cannot be used as std::vector<bool> types do not have a .data()
4550
* member and are not always stored contiguously.
4651
*
47-
* @tparam Type of container
52+
* @tparam InnerT Type within std::vector
4853
* @param x Nested container of boolean inputs
4954
* @return Boolean indicating whether all elements are true
5055
*/
51-
template <typename ContainerT,
52-
require_std_vector_st<std::is_integral, ContainerT>* = nullptr>
53-
inline bool all(const ContainerT& x) {
56+
template <typename InnerT>
57+
inline bool all(const std::vector<InnerT>& x) {
5458
return std::all_of(x.begin(), x.end(), [](const auto& i) { return all(i); });
5559
}
5660

stan/math/prim/fun/any.hpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@ namespace math {
1313
*
1414
* Overload for a single boolean input
1515
*
16+
* @tparam T The type of integral input.
1617
* @param x boolean input
1718
* @return The input unchanged
1819
*/
19-
template <typename T, require_integral_t<T>* = nullptr>
20-
constexpr inline T any(T x) {
20+
template <typename T, require_t<std::is_convertible<T, bool>>* = nullptr>
21+
constexpr inline bool any(T x) {
2122
return x;
2223
}
2324

@@ -37,20 +38,23 @@ inline bool any(const ContainerT& x) {
3738
return x.any();
3839
}
3940

41+
// Forward-declaration for correct resolution of any(std::vector<std::tuple>)
42+
template <typename... Types>
43+
inline bool any(const std::tuple<Types...>& x);
44+
4045
/**
4146
* Return true if any values in the input are true.
4247
*
4348
* Overload for a std::vector/nested inputs. The Eigen::Map/apply_vector_unary
4449
* approach cannot be used as std::vector<bool> types do not have a .data()
4550
* member and are not always stored contiguously.
4651
*
47-
* @tparam Type of container within std::vector
52+
* @tparam InnerT Type within std::vector
4853
* @param x Nested container of boolean inputs
4954
* @return Boolean indicating whether any elements are true
5055
*/
51-
template <typename ContainerT,
52-
require_std_vector_st<std::is_integral, ContainerT>* = nullptr>
53-
inline bool any(const ContainerT& x) {
56+
template <typename InnerT>
57+
inline bool any(const std::vector<InnerT>& x) {
5458
return std::any_of(x.begin(), x.end(), [](const auto& i) { return any(i); });
5559
}
5660

stan/math/prim/fun/select.hpp

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ namespace stan {
99
namespace math {
1010

1111
/**
12-
* Return the second argument if the first argument is true
13-
* and otherwise return the third argument.
12+
* If first argument is true return the second argument,
13+
* else return the third argument.
1414
*
1515
* `select(c, y1, y0) = c ? y1 : y0`.
1616
*
17-
* @tparam T_true type of the true argument
18-
* @tparam T_false type of the false argument
17+
* @tparam T_true A stan `Scalar` type
18+
* @tparam T_false A stan `Scalar` type
1919
* @param c Boolean condition value.
2020
* @param y_true Value to return if condition is true.
2121
* @param y_false Value to return if condition is false.
@@ -29,17 +29,17 @@ inline ReturnT select(const bool c, const T_true y_true,
2929
}
3030

3131
/**
32-
* Return the second argument if the first argument is true
33-
* and otherwise return the third argument. Eigen expressions are
32+
* If first argument is true return the second argument,
33+
* else return the third argument. Eigen expressions are
3434
* evaluated so that the return type is the same for both branches.
3535
*
3636
* Both containers must have the same plain type. The scalar type
3737
* of the return is determined by the return_type_t<> type trait.
3838
*
3939
* Overload for use with two containers.
4040
*
41-
* @tparam T_true type of the true argument
42-
* @tparam T_false type of the false argument
41+
* @tparam T_true A container of stan `Scalar` types
42+
* @tparam T_false A container of stan `Scalar` types
4343
* @param c Boolean condition value.
4444
* @param y_true Value to return if condition is true.
4545
* @param y_false Value to return if condition is false.
@@ -59,8 +59,8 @@ inline T_true_plain select(const bool c, T_true&& y_true, T_false&& y_false) {
5959
}
6060

6161
/**
62-
* Return the second argument if the first argument is true
63-
* and otherwise return the third argument.
62+
* If first argument is true return the second argument,
63+
* else return the third argument.
6464
*
6565
* Overload for use when the 'true' return is a container and the 'false'
6666
* return is a scalar
@@ -69,8 +69,8 @@ inline T_true_plain select(const bool c, T_true&& y_true, T_false&& y_false) {
6969
* plain type as the provided argument. Consequently, any Eigen expressions are
7070
* evaluated.
7171
*
72-
* @tparam T_true type of the true argument
73-
* @tparam T_false type of the false argument
72+
* @tparam T_true A container of stan `Scalar` types
73+
* @tparam T_false A stan `Scalar` type
7474
* @param c Boolean condition value.
7575
* @param y_true Value to return if condition is true.
7676
* @param y_false Value to return if condition is false.
@@ -94,8 +94,8 @@ inline ReturnT select(const bool c, const T_true& y_true,
9494
}
9595

9696
/**
97-
* Return the second argument if the first argument is true
98-
* and otherwise return the third argument.
97+
* If first argument is true return the second argument,
98+
* else return the third argument.
9999
*
100100
* Overload for use when the 'true' return is a scalar and the 'false'
101101
* return is a container
@@ -104,8 +104,8 @@ inline ReturnT select(const bool c, const T_true& y_true,
104104
* plain type as the provided argument. Consequently, any Eigen expressions are
105105
* evaluated.
106106
*
107-
* @tparam T_true type of the true argument
108-
* @tparam T_false type of the false argument
107+
* @tparam T_true A stan `Scalar` type
108+
* @tparam T_false A container of stan `Scalar` types
109109
* @param c Boolean condition value.
110110
* @param y_true Value to return if condition is true.
111111
* @param y_false Value to return if condition is false.
@@ -129,16 +129,16 @@ inline ReturnT select(const bool c, const T_true y_true,
129129
}
130130

131131
/**
132-
* Return the second argument if the first argument is true
133-
* and otherwise return the third argument. Overload for use with an Eigen
132+
* If first argument is true return the second argument,
133+
* else return the third argument. Overload for use with an Eigen
134134
* object of booleans, and two scalars.
135135
*
136136
* The chosen scalar is returned as an Eigen object of the same dimension
137137
* as the input Eigen argument
138138
*
139139
* @tparam T_bool type of Eigen boolean object
140-
* @tparam T_true type of the true argument
141-
* @tparam T_false type of the false argument
140+
* @tparam T_true A stan `Scalar` type
141+
* @tparam T_false A stan `Scalar` type
142142
* @param c Eigen object of boolean condition values.
143143
* @param y_true Value to return if condition is true.
144144
* @param y_false Value to return if condition is false.
@@ -147,20 +147,22 @@ template <typename T_bool, typename T_true, typename T_false,
147147
require_eigen_array_vt<std::is_integral, T_bool>* = nullptr,
148148
require_all_stan_scalar_t<T_true, T_false>* = nullptr>
149149
inline auto select(const T_bool c, const T_true y_true, const T_false y_false) {
150+
using ret_t = return_type_t<T_true, T_false>;
150151
return c
151152
.unaryExpr(
152-
[y_true, y_false](bool cond) { return cond ? y_true : y_false; })
153+
[y_true, y_false](bool cond) {
154+
return cond ? ret_t(y_true) : ret_t(y_false); })
153155
.eval();
154156
}
155157

156158
/**
157-
* Return the second argument if the first argument is true
158-
* and otherwise return the third argument. Overload for use with an Eigen
159+
* If first argument is true return the second argument,
160+
* else return the third argument. Overload for use with an Eigen
159161
* array of booleans, one Eigen array and a scalar as input.
160162
*
161163
* @tparam T_bool type of Eigen boolean object
162-
* @tparam T_true type of the true argument
163-
* @tparam T_false type of the false argument
164+
* @tparam T_true A stan `Scalar` type or Eigen Array type
165+
* @tparam T_false A stan `Scalar` type or Eigen Array type
164166
* @param c Eigen object of boolean condition values.
165167
* @param y_true Value to return if condition is true.
166168
* @param y_false Value to return if condition is false.
@@ -171,7 +173,8 @@ template <typename T_bool, typename T_true, typename T_false,
171173
inline auto select(const T_bool c, const T_true y_true, const T_false y_false) {
172174
check_consistent_sizes("select", "boolean", c, "left hand side", y_true,
173175
"right hand side", y_false);
174-
return c.select(y_true, y_false).eval();
176+
using ret_t = return_type_t<T_true, T_false>;
177+
return c.select(y_true, y_false).template cast<ret_t>().eval();
175178
}
176179

177180
} // namespace math

test/unit/math/prim/fun/all_test.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <gtest/gtest.h>
33
#include <cmath>
44
#include <limits>
5+
#include <vector>
56

67
TEST(MathFunctions, all) {
78
using stan::math::all;
@@ -23,20 +24,24 @@ TEST(MathFunctions, all) {
2324

2425
auto bool_tuple
2526
= std::make_tuple(true, bool_stdvector, bool_array, nested_bool);
27+
std::vector<decltype(bool_tuple)> stdvec_bool_tuple{bool_tuple};
2628

2729
EXPECT_FALSE(all(inp < 2));
2830
EXPECT_TRUE(all(bool_stdvector));
2931
EXPECT_TRUE(all(bool_array));
3032
EXPECT_TRUE(all(nested_bool));
3133
EXPECT_TRUE(all(bool_tuple));
34+
EXPECT_TRUE(all(stdvec_bool_tuple));
3235

3336
bool_array(2) = false;
3437
nested_bool[1](3) = false;
3538
bool_stdvector[1] = false;
3639
std::get<3>(bool_tuple) = nested_bool;
40+
stdvec_bool_tuple[0] = bool_tuple;
3741

3842
EXPECT_FALSE(all(bool_array));
3943
EXPECT_FALSE(all(nested_bool));
4044
EXPECT_FALSE(all(bool_stdvector));
4145
EXPECT_FALSE(all(bool_tuple));
46+
EXPECT_FALSE(all(stdvec_bool_tuple));
4247
}

test/unit/math/prim/fun/any_test.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,24 @@ TEST(MathFunctions, any) {
2323

2424
auto bool_tuple
2525
= std::make_tuple(false, bool_stdvector, bool_array, nested_bool);
26+
std::vector<decltype(bool_tuple)> stdvec_bool_tuple{bool_tuple};
2627

2728
EXPECT_TRUE(any(inp < 2));
2829
EXPECT_FALSE(any(bool_stdvector));
2930
EXPECT_FALSE(any(bool_array));
3031
EXPECT_FALSE(any(nested_bool));
3132
EXPECT_FALSE(any(bool_tuple));
33+
EXPECT_FALSE(any(stdvec_bool_tuple));
3234

3335
bool_array(2) = true;
3436
nested_bool[1](3) = true;
3537
bool_stdvector[1] = true;
3638
std::get<3>(bool_tuple) = nested_bool;
39+
stdvec_bool_tuple[0] = bool_tuple;
3740

3841
EXPECT_TRUE(any(bool_array));
3942
EXPECT_TRUE(any(nested_bool));
4043
EXPECT_TRUE(any(bool_stdvector));
4144
EXPECT_TRUE(any(bool_tuple));
45+
EXPECT_TRUE(any(stdvec_bool_tuple));
4246
}

0 commit comments

Comments
 (0)