Skip to content

Commit 38289cd

Browse files
authored
Merge pull request #2853 from andrjohns/vectorised-select
Add vectorised select(), any(), and all() functions
2 parents 34881d4 + 64728a0 commit 38289cd

8 files changed

Lines changed: 656 additions & 16 deletions

File tree

stan/math/opencl/kernel_generator/select.hpp

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#ifdef STAN_OPENCL
44

55
#include <stan/math/prim/meta.hpp>
6+
#include <stan/math/prim/fun/select.hpp>
67
#include <stan/math/opencl/matrix_cl_view.hpp>
78
#include <stan/math/opencl/kernel_generator/type_str.hpp>
89
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
@@ -150,22 +151,6 @@ select(T_condition&& condition, T_then&& then, T_else&& els) { // NOLINT
150151
as_operation_cl(std::forward<T_else>(els))};
151152
}
152153

153-
/**
154-
* Scalar overload of the selection operation.
155-
* @tparam T_then type of then scalar
156-
* @tparam T_else type of else scalar
157-
* @param condition condition
158-
* @param then then result
159-
* @param els else result
160-
* @return `condition ? then : els`
161-
*/
162-
template <typename T_then, typename T_else,
163-
require_all_arithmetic_t<T_then, T_else>* = nullptr>
164-
inline std::common_type_t<T_then, T_else> select(bool condition, T_then then,
165-
T_else els) {
166-
return condition ? then : els;
167-
}
168-
169154
/** @}*/
170155
} // namespace math
171156
} // namespace stan

stan/math/prim/fun.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include <stan/math/prim/fun/acosh.hpp>
88
#include <stan/math/prim/fun/add.hpp>
99
#include <stan/math/prim/fun/add_diag.hpp>
10+
#include <stan/math/prim/fun/all.hpp>
11+
#include <stan/math/prim/fun/any.hpp>
1012
#include <stan/math/prim/fun/append_array.hpp>
1113
#include <stan/math/prim/fun/append_col.hpp>
1214
#include <stan/math/prim/fun/append_row.hpp>
@@ -305,6 +307,7 @@
305307
#include <stan/math/prim/fun/scaled_add.hpp>
306308
#include <stan/math/prim/fun/sd.hpp>
307309
#include <stan/math/prim/fun/segment.hpp>
310+
#include <stan/math/prim/fun/select.hpp>
308311
#include <stan/math/prim/fun/sign.hpp>
309312
#include <stan/math/prim/fun/signbit.hpp>
310313
#include <stan/math/prim/fun/simplex_constrain.hpp>

stan/math/prim/fun/all.hpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#ifndef STAN_MATH_PRIM_FUN_ALL_HPP
2+
#define STAN_MATH_PRIM_FUN_ALL_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/functor/for_each.hpp>
6+
#include <algorithm>
7+
8+
namespace stan {
9+
namespace math {
10+
11+
/**
12+
* Return true if all values in the input are true.
13+
*
14+
* Overload for a single integral input
15+
*
16+
* @tparam T Any type convertible to `bool`
17+
* @param x integral input
18+
* @return The input unchanged
19+
*/
20+
template <typename T, require_t<std::is_convertible<T, bool>>* = nullptr>
21+
constexpr inline bool all(T x) {
22+
return x;
23+
}
24+
25+
/**
26+
* Return true if all values in the input are true.
27+
*
28+
* Overload for Eigen types
29+
*
30+
* @tparam ContainerT A type derived from `Eigen::EigenBase` that has an
31+
* `integral` scalar type
32+
* @param x Eigen object of boolean inputs
33+
* @return Boolean indicating whether all elements are true
34+
*/
35+
template <typename ContainerT,
36+
require_eigen_st<std::is_integral, ContainerT>* = nullptr>
37+
inline bool all(const ContainerT& x) {
38+
return x.all();
39+
}
40+
41+
// Forward-declaration for correct resolution of all(std::vector<std::tuple>)
42+
template <typename... Types>
43+
inline bool all(const std::tuple<Types...>& x);
44+
45+
/**
46+
* Return true if all values in the input are true.
47+
*
48+
* Overload for a std::vector/nested inputs. The Eigen::Map/apply_vector_unary
49+
* approach cannot be used as std::vector<bool> types do not have a .data()
50+
* member and are not always stored contiguously.
51+
*
52+
* @tparam InnerT Type within std::vector
53+
* @param x Nested container of boolean inputs
54+
* @return Boolean indicating whether all elements are true
55+
*/
56+
template <typename InnerT>
57+
inline bool all(const std::vector<InnerT>& x) {
58+
return std::all_of(x.begin(), x.end(), [](const auto& i) { return all(i); });
59+
}
60+
61+
/**
62+
* Return true if all values in the input are true.
63+
*
64+
* Overload for a tuple input.
65+
*
66+
* @tparam Types of items within tuple
67+
* @param x Tuple of boolean scalar-type elements
68+
* @return Boolean indicating whether all elements are true
69+
*/
70+
template <typename... Types>
71+
inline bool all(const std::tuple<Types...>& x) {
72+
bool all_true = true;
73+
math::for_each(
74+
[&all_true](const auto& i) {
75+
all_true = all_true && all(i);
76+
return;
77+
},
78+
x);
79+
return all_true;
80+
}
81+
82+
} // namespace math
83+
} // namespace stan
84+
85+
#endif

stan/math/prim/fun/any.hpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#ifndef STAN_MATH_PRIM_FUN_ANY_HPP
2+
#define STAN_MATH_PRIM_FUN_ANY_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/functor/for_each.hpp>
6+
#include <algorithm>
7+
8+
namespace stan {
9+
namespace math {
10+
11+
/**
12+
* Return true if any values in the input are true.
13+
*
14+
* Overload for a single boolean input
15+
*
16+
* @tparam T Any type convertible to `bool`
17+
* @param x boolean input
18+
* @return The input unchanged
19+
*/
20+
template <typename T, require_t<std::is_convertible<T, bool>>* = nullptr>
21+
constexpr inline bool any(T x) {
22+
return x;
23+
}
24+
25+
/**
26+
* Return true if any values in the input are true.
27+
*
28+
* Overload for Eigen types
29+
*
30+
* @tparam ContainerT A type derived from `Eigen::EigenBase` that has an
31+
* `integral` scalar type
32+
* @param x Eigen object of boolean inputs
33+
* @return Boolean indicating whether any elements are true
34+
*/
35+
template <typename ContainerT,
36+
require_eigen_st<std::is_integral, ContainerT>* = nullptr>
37+
inline bool any(const ContainerT& x) {
38+
return x.any();
39+
}
40+
41+
// Forward-declaration for correct resolution of any(std::vector<std::tuple>)
42+
template <typename... Types>
43+
inline bool any(const std::tuple<Types...>& x);
44+
45+
/**
46+
* Return true if any values in the input are true.
47+
*
48+
* Overload for a std::vector/nested inputs. The Eigen::Map/apply_vector_unary
49+
* approach cannot be used as std::vector<bool> types do not have a .data()
50+
* member and are not always stored contiguously.
51+
*
52+
* @tparam InnerT Type within std::vector
53+
* @param x Nested container of boolean inputs
54+
* @return Boolean indicating whether any elements are true
55+
*/
56+
template <typename InnerT>
57+
inline bool any(const std::vector<InnerT>& x) {
58+
return std::any_of(x.begin(), x.end(), [](const auto& i) { return any(i); });
59+
}
60+
61+
/**
62+
* Return true if any values in the input are true.
63+
*
64+
* Overload for a tuple input.
65+
*
66+
* @tparam Types of items within tuple
67+
* @param x Tuple of boolean scalar-type elements
68+
* @return Boolean indicating whether any elements are true
69+
*/
70+
template <typename... Types>
71+
inline bool any(const std::tuple<Types...>& x) {
72+
bool any_true = false;
73+
math::for_each(
74+
[&any_true](const auto& i) {
75+
any_true = any_true || any(i);
76+
return;
77+
},
78+
x);
79+
return any_true;
80+
}
81+
82+
} // namespace math
83+
} // namespace stan
84+
85+
#endif

0 commit comments

Comments
 (0)