Skip to content

Commit c09bcf4

Browse files
committed
Merge conflict
2 parents 85c9e5e + 5d1fd38 commit c09bcf4

63 files changed

Lines changed: 1896 additions & 88 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

stan/math/opencl/kernel_generator/select.hpp

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#ifdef STAN_OPENCL
44

55
#include <stan/math/prim/meta.hpp>
6+
#include <stan/math/prim/fun/select.hpp>
67
#include <stan/math/opencl/matrix_cl_view.hpp>
78
#include <stan/math/opencl/kernel_generator/type_str.hpp>
89
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
@@ -150,22 +151,6 @@ select(T_condition&& condition, T_then&& then, T_else&& els) { // NOLINT
150151
as_operation_cl(std::forward<T_else>(els))};
151152
}
152153

153-
/**
154-
* Scalar overload of the selection operation.
155-
* @tparam T_then type of then scalar
156-
* @tparam T_else type of else scalar
157-
* @param condition condition
158-
* @param then then result
159-
* @param els else result
160-
* @return `condition ? then : els`
161-
*/
162-
template <typename T_then, typename T_else,
163-
require_all_arithmetic_t<T_then, T_else>* = nullptr>
164-
inline std::common_type_t<T_then, T_else> select(bool condition, T_then then,
165-
T_else els) {
166-
return condition ? then : els;
167-
}
168-
169154
/** @}*/
170155
} // namespace math
171156
} // namespace stan

stan/math/prim/fun.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include <stan/math/prim/fun/acosh.hpp>
88
#include <stan/math/prim/fun/add.hpp>
99
#include <stan/math/prim/fun/add_diag.hpp>
10+
#include <stan/math/prim/fun/all.hpp>
11+
#include <stan/math/prim/fun/any.hpp>
1012
#include <stan/math/prim/fun/append_array.hpp>
1113
#include <stan/math/prim/fun/append_col.hpp>
1214
#include <stan/math/prim/fun/append_row.hpp>
@@ -62,6 +64,7 @@
6264
#include <stan/math/prim/fun/cov_matrix_free.hpp>
6365
#include <stan/math/prim/fun/cov_matrix_free_lkj.hpp>
6466
#include <stan/math/prim/fun/crossprod.hpp>
67+
#include <stan/math/prim/fun/csr_extract.hpp>
6568
#include <stan/math/prim/fun/csr_extract_u.hpp>
6669
#include <stan/math/prim/fun/csr_extract_v.hpp>
6770
#include <stan/math/prim/fun/csr_extract_w.hpp>
@@ -82,6 +85,8 @@
8285
#include <stan/math/prim/fun/dot_product.hpp>
8386
#include <stan/math/prim/fun/dot_self.hpp>
8487
#include <stan/math/prim/fun/eigen_comparisons.hpp>
88+
#include <stan/math/prim/fun/eigendecompose.hpp>
89+
#include <stan/math/prim/fun/eigendecompose_sym.hpp>
8590
#include <stan/math/prim/fun/eigenvalues.hpp>
8691
#include <stan/math/prim/fun/eigenvalues_sym.hpp>
8792
#include <stan/math/prim/fun/eigenvectors.hpp>
@@ -275,6 +280,7 @@
275280
#include <stan/math/prim/fun/qr.hpp>
276281
#include <stan/math/prim/fun/qr_Q.hpp>
277282
#include <stan/math/prim/fun/qr_R.hpp>
283+
#include <stan/math/prim/fun/qr_thin.hpp>
278284
#include <stan/math/prim/fun/qr_thin_Q.hpp>
279285
#include <stan/math/prim/fun/qr_thin_R.hpp>
280286
#include <stan/math/prim/fun/quad_form.hpp>
@@ -306,6 +312,7 @@
306312
#include <stan/math/prim/fun/sd.hpp>
307313
#include <stan/math/prim/fun/segment.hpp>
308314
#include <stan/math/prim/fun/serializer.hpp>
315+
#include <stan/math/prim/fun/select.hpp>
309316
#include <stan/math/prim/fun/sign.hpp>
310317
#include <stan/math/prim/fun/signbit.hpp>
311318
#include <stan/math/prim/fun/simplex_constrain.hpp>
@@ -331,6 +338,7 @@
331338
#include <stan/math/prim/fun/sub_row.hpp>
332339
#include <stan/math/prim/fun/subtract.hpp>
333340
#include <stan/math/prim/fun/sum.hpp>
341+
#include <stan/math/prim/fun/svd.hpp>
334342
#include <stan/math/prim/fun/svd_U.hpp>
335343
#include <stan/math/prim/fun/svd_V.hpp>
336344
#include <stan/math/prim/fun/symmetrize_from_lower_tri.hpp>

stan/math/prim/fun/all.hpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#ifndef STAN_MATH_PRIM_FUN_ALL_HPP
2+
#define STAN_MATH_PRIM_FUN_ALL_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/functor/for_each.hpp>
6+
#include <algorithm>
7+
8+
namespace stan {
9+
namespace math {
10+
11+
/**
12+
* Return true if all values in the input are true.
13+
*
14+
* Overload for a single integral input
15+
*
16+
* @tparam T Any type convertible to `bool`
17+
* @param x integral input
18+
* @return The input unchanged
19+
*/
20+
template <typename T, require_t<std::is_convertible<T, bool>>* = nullptr>
21+
constexpr inline bool all(T x) {
22+
return x;
23+
}
24+
25+
/**
26+
* Return true if all values in the input are true.
27+
*
28+
* Overload for Eigen types
29+
*
30+
* @tparam ContainerT A type derived from `Eigen::EigenBase` that has an
31+
* `integral` scalar type
32+
* @param x Eigen object of boolean inputs
33+
* @return Boolean indicating whether all elements are true
34+
*/
35+
template <typename ContainerT,
36+
require_eigen_st<std::is_integral, ContainerT>* = nullptr>
37+
inline bool all(const ContainerT& x) {
38+
return x.all();
39+
}
40+
41+
// Forward-declaration for correct resolution of all(std::vector<std::tuple>)
42+
template <typename... Types>
43+
inline bool all(const std::tuple<Types...>& x);
44+
45+
/**
46+
* Return true if all values in the input are true.
47+
*
48+
* Overload for a std::vector/nested inputs. The Eigen::Map/apply_vector_unary
49+
* approach cannot be used as std::vector<bool> types do not have a .data()
50+
* member and are not always stored contiguously.
51+
*
52+
* @tparam InnerT Type within std::vector
53+
* @param x Nested container of boolean inputs
54+
* @return Boolean indicating whether all elements are true
55+
*/
56+
template <typename InnerT>
57+
inline bool all(const std::vector<InnerT>& x) {
58+
return std::all_of(x.begin(), x.end(), [](const auto& i) { return all(i); });
59+
}
60+
61+
/**
62+
* Return true if all values in the input are true.
63+
*
64+
* Overload for a tuple input.
65+
*
66+
* @tparam Types of items within tuple
67+
* @param x Tuple of boolean scalar-type elements
68+
* @return Boolean indicating whether all elements are true
69+
*/
70+
template <typename... Types>
71+
inline bool all(const std::tuple<Types...>& x) {
72+
bool all_true = true;
73+
math::for_each(
74+
[&all_true](const auto& i) {
75+
all_true = all_true && all(i);
76+
return;
77+
},
78+
x);
79+
return all_true;
80+
}
81+
82+
} // namespace math
83+
} // namespace stan
84+
85+
#endif

stan/math/prim/fun/any.hpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#ifndef STAN_MATH_PRIM_FUN_ANY_HPP
2+
#define STAN_MATH_PRIM_FUN_ANY_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/functor/for_each.hpp>
6+
#include <algorithm>
7+
8+
namespace stan {
9+
namespace math {
10+
11+
/**
12+
* Return true if any values in the input are true.
13+
*
14+
* Overload for a single boolean input
15+
*
16+
* @tparam T Any type convertible to `bool`
17+
* @param x boolean input
18+
* @return The input unchanged
19+
*/
20+
template <typename T, require_t<std::is_convertible<T, bool>>* = nullptr>
21+
constexpr inline bool any(T x) {
22+
return x;
23+
}
24+
25+
/**
26+
* Return true if any values in the input are true.
27+
*
28+
* Overload for Eigen types
29+
*
30+
* @tparam ContainerT A type derived from `Eigen::EigenBase` that has an
31+
* `integral` scalar type
32+
* @param x Eigen object of boolean inputs
33+
* @return Boolean indicating whether any elements are true
34+
*/
35+
template <typename ContainerT,
36+
require_eigen_st<std::is_integral, ContainerT>* = nullptr>
37+
inline bool any(const ContainerT& x) {
38+
return x.any();
39+
}
40+
41+
// Forward-declaration for correct resolution of any(std::vector<std::tuple>)
42+
template <typename... Types>
43+
inline bool any(const std::tuple<Types...>& x);
44+
45+
/**
46+
* Return true if any values in the input are true.
47+
*
48+
* Overload for a std::vector/nested inputs. The Eigen::Map/apply_vector_unary
49+
* approach cannot be used as std::vector<bool> types do not have a .data()
50+
* member and are not always stored contiguously.
51+
*
52+
* @tparam InnerT Type within std::vector
53+
* @param x Nested container of boolean inputs
54+
* @return Boolean indicating whether any elements are true
55+
*/
56+
template <typename InnerT>
57+
inline bool any(const std::vector<InnerT>& x) {
58+
return std::any_of(x.begin(), x.end(), [](const auto& i) { return any(i); });
59+
}
60+
61+
/**
62+
* Return true if any values in the input are true.
63+
*
64+
* Overload for a tuple input.
65+
*
66+
* @tparam Types of items within tuple
67+
* @param x Tuple of boolean scalar-type elements
68+
* @return Boolean indicating whether any elements are true
69+
*/
70+
template <typename... Types>
71+
inline bool any(const std::tuple<Types...>& x) {
72+
bool any_true = false;
73+
math::for_each(
74+
[&any_true](const auto& i) {
75+
any_true = any_true || any(i);
76+
return;
77+
},
78+
x);
79+
return any_true;
80+
}
81+
82+
} // namespace math
83+
} // namespace stan
84+
85+
#endif

stan/math/prim/fun/complex_schur_decompose.hpp

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@ namespace math {
2727
template <typename M, require_eigen_dense_dynamic_t<M>* = nullptr>
2828
inline Eigen::Matrix<complex_return_t<scalar_type_t<M>>, -1, -1>
2929
complex_schur_decompose_u(const M& m) {
30-
if (m.size() == 0)
30+
if (unlikely(m.size() == 0)) {
3131
return m;
32+
}
3233
check_square("complex_schur_decompose_u", "m", m);
3334
using MatType = Eigen::Matrix<scalar_type_t<M>, -1, -1>;
3435
// copy because ComplexSchur requires Eigen::Matrix type
35-
MatType mv = m;
36-
Eigen::ComplexSchur<MatType> cs(mv);
36+
Eigen::ComplexSchur<MatType> cs{MatType(m)};
3737
return cs.matrixU();
3838
}
3939

@@ -51,16 +51,46 @@ complex_schur_decompose_u(const M& m) {
5151
template <typename M, require_eigen_dense_dynamic_t<M>* = nullptr>
5252
inline Eigen::Matrix<complex_return_t<scalar_type_t<M>>, -1, -1>
5353
complex_schur_decompose_t(const M& m) {
54-
if (m.size() == 0)
54+
if (unlikely(m.size() == 0)) {
5555
return m;
56+
}
5657
check_square("complex_schur_decompose_t", "m", m);
5758
using MatType = Eigen::Matrix<scalar_type_t<M>, -1, -1>;
5859
// copy because ComplexSchur requires Eigen::Matrix type
59-
MatType mv = m;
60-
Eigen::ComplexSchur<MatType> cs(mv, false);
60+
Eigen::ComplexSchur<MatType> cs{MatType(m), false};
6161
return cs.matrixT();
6262
}
6363

64+
/**
65+
* Return the complex Schur decomposition of the
66+
* specified square matrix.
67+
*
68+
* The complex Schur decomposition of a square matrix `A` produces a
69+
* complex unitary matrix `U` and a complex upper-triangular Schur
70+
* form matrix `T` such that `A = U * T * inv(U)`. Further, the
71+
* unitary matrix's inverse is equal to its conjugate transpose,
72+
* `inv(U) = U*`, where `U*(i, j) = conj(U(j, i))`
73+
*
74+
* @tparam M type of matrix
75+
* @param m real matrix to decompose
76+
* @return a tuple (U,T) where U is the complex unitary matrix of the complex
77+
* Schur decomposition of `m` and T is the Schur form matrix of
78+
* the complex Schur decomposition of `m`
79+
*/
80+
template <typename M, require_eigen_dense_dynamic_t<M>* = nullptr>
81+
inline std::tuple<Eigen::Matrix<complex_return_t<scalar_type_t<M>>, -1, -1>,
82+
Eigen::Matrix<complex_return_t<scalar_type_t<M>>, -1, -1>>
83+
complex_schur_decompose(const M& m) {
84+
if (unlikely(m.size() == 0)) {
85+
return std::make_tuple(m, m);
86+
}
87+
check_square("complex_schur_decompose", "m", m);
88+
using MatType = Eigen::Matrix<scalar_type_t<M>, -1, -1>;
89+
// copy because ComplexSchur requires Eigen::Matrix type
90+
Eigen::ComplexSchur<MatType> cs{MatType(m)};
91+
return std::make_tuple(std::move(cs.matrixU()), std::move(cs.matrixT()));
92+
}
93+
6494
} // namespace math
6595
} // namespace stan
6696
#endif

stan/math/prim/fun/csr_extract.hpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#ifndef STAN_MATH_PRIM_FUN_CSR_EXTRACT_HPP
2+
#define STAN_MATH_PRIM_FUN_CSR_EXTRACT_HPP
3+
4+
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/fun/to_ref.hpp>
6+
7+
namespace stan {
8+
namespace math {
9+
10+
/** \addtogroup csr_format
11+
* @{
12+
*/
13+
14+
/**
15+
* Extract the non-zero values, column indexes for non-zero values, and
16+
* the NZE index for each entry from a sparse matrix.
17+
*
18+
* @tparam T type of elements in the matrix
19+
* @param[in] A sparse matrix.
20+
* @return a tuple W,V,U.
21+
*/
22+
template <typename T>
23+
const std::tuple<Eigen::Matrix<T, Eigen::Dynamic, 1>, std::vector<int>,
24+
std::vector<int>>
25+
csr_extract(const Eigen::SparseMatrix<T, Eigen::RowMajor>& A) {
26+
auto a_nonzeros = A.nonZeros();
27+
Eigen::Matrix<T, Eigen::Dynamic, 1> w
28+
= Eigen::Matrix<T, Eigen::Dynamic, 1>::Zero(a_nonzeros);
29+
std::vector<int> v(a_nonzeros);
30+
for (int nze = 0; nze < a_nonzeros; ++nze) {
31+
w[nze] = *(A.valuePtr() + nze);
32+
v[nze] = *(A.innerIndexPtr() + nze) + stan::error_index::value;
33+
}
34+
std::vector<int> u(A.outerSize() + 1); // last entry is garbage.
35+
for (int nze = 0; nze <= A.outerSize(); ++nze) {
36+
u[nze] = *(A.outerIndexPtr() + nze) + stan::error_index::value;
37+
}
38+
return std::make_tuple(std::move(w), std::move(v), std::move(u));
39+
}
40+
41+
/* Extract the non-zero values from a dense matrix by converting
42+
* to sparse and calling the sparse matrix extractor.
43+
*
44+
* @tparam T type of elements in the matrix
45+
* @tparam R number of rows, can be Eigen::Dynamic
46+
* @tparam C number of columns, can be Eigen::Dynamic
47+
*
48+
* @param[in] A dense matrix.
49+
* @return a tuple W,V,U.
50+
*/
51+
template <typename T, require_eigen_dense_base_t<T>* = nullptr>
52+
const std::tuple<Eigen::Matrix<scalar_type_t<T>, Eigen::Dynamic, 1>,
53+
std::vector<int>, std::vector<int>>
54+
csr_extract(const T& A) {
55+
// conversion to sparse seems to touch data twice, so we need to call to_ref
56+
Eigen::SparseMatrix<scalar_type_t<T>, Eigen::RowMajor> B
57+
= to_ref(A).sparseView();
58+
return csr_extract(B);
59+
}
60+
61+
/** @} */ // end of csr_format group
62+
63+
} // namespace math
64+
} // namespace stan
65+
66+
#endif

0 commit comments

Comments
 (0)