Skip to content

Commit 5d1fd38

Browse files
authored
Merge pull request #2931 from stan-dev/feature/2845-tuple-fns
Add tuple-returning special functions
2 parents d4eab27 + 497dc71 commit 5d1fd38

53 files changed

Lines changed: 1194 additions & 71 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

stan/math/prim/fun.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
#include <stan/math/prim/fun/cov_matrix_free.hpp>
6565
#include <stan/math/prim/fun/cov_matrix_free_lkj.hpp>
6666
#include <stan/math/prim/fun/crossprod.hpp>
67+
#include <stan/math/prim/fun/csr_extract.hpp>
6768
#include <stan/math/prim/fun/csr_extract_u.hpp>
6869
#include <stan/math/prim/fun/csr_extract_v.hpp>
6970
#include <stan/math/prim/fun/csr_extract_w.hpp>
@@ -84,6 +85,8 @@
8485
#include <stan/math/prim/fun/dot_product.hpp>
8586
#include <stan/math/prim/fun/dot_self.hpp>
8687
#include <stan/math/prim/fun/eigen_comparisons.hpp>
88+
#include <stan/math/prim/fun/eigendecompose.hpp>
89+
#include <stan/math/prim/fun/eigendecompose_sym.hpp>
8790
#include <stan/math/prim/fun/eigenvalues.hpp>
8891
#include <stan/math/prim/fun/eigenvalues_sym.hpp>
8992
#include <stan/math/prim/fun/eigenvectors.hpp>
@@ -277,6 +280,7 @@
277280
#include <stan/math/prim/fun/qr.hpp>
278281
#include <stan/math/prim/fun/qr_Q.hpp>
279282
#include <stan/math/prim/fun/qr_R.hpp>
283+
#include <stan/math/prim/fun/qr_thin.hpp>
280284
#include <stan/math/prim/fun/qr_thin_Q.hpp>
281285
#include <stan/math/prim/fun/qr_thin_R.hpp>
282286
#include <stan/math/prim/fun/quad_form.hpp>
@@ -333,6 +337,7 @@
333337
#include <stan/math/prim/fun/sub_row.hpp>
334338
#include <stan/math/prim/fun/subtract.hpp>
335339
#include <stan/math/prim/fun/sum.hpp>
340+
#include <stan/math/prim/fun/svd.hpp>
336341
#include <stan/math/prim/fun/svd_U.hpp>
337342
#include <stan/math/prim/fun/svd_V.hpp>
338343
#include <stan/math/prim/fun/symmetrize_from_lower_tri.hpp>

stan/math/prim/fun/complex_schur_decompose.hpp

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@ namespace math {
2727
template <typename M, require_eigen_dense_dynamic_t<M>* = nullptr>
2828
inline Eigen::Matrix<complex_return_t<scalar_type_t<M>>, -1, -1>
2929
complex_schur_decompose_u(const M& m) {
30-
if (m.size() == 0)
30+
if (unlikely(m.size() == 0)) {
3131
return m;
32+
}
3233
check_square("complex_schur_decompose_u", "m", m);
3334
using MatType = Eigen::Matrix<scalar_type_t<M>, -1, -1>;
3435
// copy because ComplexSchur requires Eigen::Matrix type
35-
MatType mv = m;
36-
Eigen::ComplexSchur<MatType> cs(mv);
36+
Eigen::ComplexSchur<MatType> cs{MatType(m)};
3737
return cs.matrixU();
3838
}
3939

@@ -51,16 +51,46 @@ complex_schur_decompose_u(const M& m) {
5151
template <typename M, require_eigen_dense_dynamic_t<M>* = nullptr>
5252
inline Eigen::Matrix<complex_return_t<scalar_type_t<M>>, -1, -1>
5353
complex_schur_decompose_t(const M& m) {
54-
if (m.size() == 0)
54+
if (unlikely(m.size() == 0)) {
5555
return m;
56+
}
5657
check_square("complex_schur_decompose_t", "m", m);
5758
using MatType = Eigen::Matrix<scalar_type_t<M>, -1, -1>;
5859
// copy because ComplexSchur requires Eigen::Matrix type
59-
MatType mv = m;
60-
Eigen::ComplexSchur<MatType> cs(mv, false);
60+
Eigen::ComplexSchur<MatType> cs{MatType(m), false};
6161
return cs.matrixT();
6262
}
6363

64+
/**
65+
* Return the complex Schur decomposition of the
66+
* specified square matrix.
67+
*
68+
* The complex Schur decomposition of a square matrix `A` produces a
69+
* complex unitary matrix `U` and a complex upper-triangular Schur
70+
* form matrix `T` such that `A = U * T * inv(U)`. Further, the
71+
* unitary matrix's inverse is equal to its conjugate transpose,
72+
* `inv(U) = U*`, where `U*(i, j) = conj(U(j, i))`
73+
*
74+
* @tparam M type of matrix
75+
* @param m real matrix to decompose
76+
* @return a tuple (U,T) where U is the complex unitary matrix of the complex
77+
* Schur decomposition of `m` and T is the Schur form matrix of
78+
* the complex Schur decomposition of `m`
79+
*/
80+
template <typename M, require_eigen_dense_dynamic_t<M>* = nullptr>
81+
inline std::tuple<Eigen::Matrix<complex_return_t<scalar_type_t<M>>, -1, -1>,
82+
Eigen::Matrix<complex_return_t<scalar_type_t<M>>, -1, -1>>
83+
complex_schur_decompose(const M& m) {
84+
if (unlikely(m.size() == 0)) {
85+
return std::make_tuple(m, m);
86+
}
87+
check_square("complex_schur_decompose", "m", m);
88+
using MatType = Eigen::Matrix<scalar_type_t<M>, -1, -1>;
89+
// copy because ComplexSchur requires Eigen::Matrix type
90+
Eigen::ComplexSchur<MatType> cs{MatType(m)};
91+
return std::make_tuple(std::move(cs.matrixU()), std::move(cs.matrixT()));
92+
}
93+
6494
} // namespace math
6595
} // namespace stan
6696
#endif

stan/math/prim/fun/csr_extract.hpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#ifndef STAN_MATH_PRIM_FUN_CSR_EXTRACT_HPP
2+
#define STAN_MATH_PRIM_FUN_CSR_EXTRACT_HPP
3+
4+
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/fun/to_ref.hpp>
6+
7+
namespace stan {
8+
namespace math {
9+
10+
/** \addtogroup csr_format
11+
* @{
12+
*/
13+
14+
/**
15+
* Extract the non-zero values, column indexes for non-zero values, and
16+
* the NZE index for each entry from a sparse matrix.
17+
*
18+
* @tparam T type of elements in the matrix
19+
* @param[in] A sparse matrix.
20+
* @return a tuple W,V,U.
21+
*/
22+
template <typename T>
23+
const std::tuple<Eigen::Matrix<T, Eigen::Dynamic, 1>, std::vector<int>,
24+
std::vector<int>>
25+
csr_extract(const Eigen::SparseMatrix<T, Eigen::RowMajor>& A) {
26+
auto a_nonzeros = A.nonZeros();
27+
Eigen::Matrix<T, Eigen::Dynamic, 1> w
28+
= Eigen::Matrix<T, Eigen::Dynamic, 1>::Zero(a_nonzeros);
29+
std::vector<int> v(a_nonzeros);
30+
for (int nze = 0; nze < a_nonzeros; ++nze) {
31+
w[nze] = *(A.valuePtr() + nze);
32+
v[nze] = *(A.innerIndexPtr() + nze) + stan::error_index::value;
33+
}
34+
std::vector<int> u(A.outerSize() + 1); // last entry is garbage.
35+
for (int nze = 0; nze <= A.outerSize(); ++nze) {
36+
u[nze] = *(A.outerIndexPtr() + nze) + stan::error_index::value;
37+
}
38+
return std::make_tuple(std::move(w), std::move(v), std::move(u));
39+
}
40+
41+
/* Extract the non-zero values from a dense matrix by converting
42+
* to sparse and calling the sparse matrix extractor.
43+
*
44+
* @tparam T type of elements in the matrix
45+
* @tparam R number of rows, can be Eigen::Dynamic
46+
* @tparam C number of columns, can be Eigen::Dynamic
47+
*
48+
* @param[in] A dense matrix.
49+
* @return a tuple W,V,U.
50+
*/
51+
template <typename T, require_eigen_dense_base_t<T>* = nullptr>
52+
const std::tuple<Eigen::Matrix<scalar_type_t<T>, Eigen::Dynamic, 1>,
53+
std::vector<int>, std::vector<int>>
54+
csr_extract(const T& A) {
55+
// conversion to sparse seems to touch data twice, so we need to call to_ref
56+
Eigen::SparseMatrix<scalar_type_t<T>, Eigen::RowMajor> B
57+
= to_ref(A).sparseView();
58+
return csr_extract(B);
59+
}
60+
61+
/** @} */ // end of csr_format group
62+
63+
} // namespace math
64+
} // namespace stan
65+
66+
#endif

stan/math/prim/fun/csr_extract_w.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ namespace math {
2020
template <typename T>
2121
const Eigen::Matrix<T, Eigen::Dynamic, 1> csr_extract_w(
2222
const Eigen::SparseMatrix<T, Eigen::RowMajor>& A) {
23-
Eigen::Matrix<T, Eigen::Dynamic, 1> w(A.nonZeros());
24-
w.setZero();
25-
for (int nze = 0; nze < A.nonZeros(); ++nze) {
23+
auto a_nonzeros = A.nonZeros();
24+
Eigen::Matrix<T, Eigen::Dynamic, 1> w
25+
= Eigen::Matrix<T, Eigen::Dynamic, 1>::Zero(a_nonzeros);
26+
for (int nze = 0; nze < a_nonzeros; ++nze) {
2627
w[nze] = *(A.valuePtr() + nze);
2728
}
2829
return w;
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#ifndef STAN_MATH_PRIM_FUN_EIGENDECOMPOSE_HPP
2+
#define STAN_MATH_PRIM_FUN_EIGENDECOMPOSE_HPP
3+
4+
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/err.hpp>
6+
7+
namespace stan {
8+
namespace math {
9+
10+
/**
11+
* Return the eigendecomposition of a (real-valued) matrix
12+
*
13+
* @tparam EigMat type of real matrix argument
14+
* @param[in] m matrix to find the eigendecomposition of. Must be square and
15+
* have a non-zero size.
16+
* @return A tuple V,D where V is a matrix where the columns are the
17+
* complex-valued eigenvectors of `m` and D is a complex-valued column vector
18+
* with entries the eigenvectors of `m`
19+
*/
20+
template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr,
21+
require_not_vt_complex<EigMat>* = nullptr>
22+
inline std::tuple<Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, -1, -1>,
23+
Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, -1, 1>>
24+
eigendecompose(const EigMat& m) {
25+
if (unlikely(m.size() == 0)) {
26+
return std::make_tuple(
27+
Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, -1, -1>(0, 0),
28+
Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, -1, 1>(0, 1));
29+
}
30+
check_square("eigendecompose", "m", m);
31+
32+
using PlainMat = plain_type_t<EigMat>;
33+
const PlainMat& m_eval = m;
34+
35+
Eigen::EigenSolver<PlainMat> solver(m_eval);
36+
return std::make_tuple(std::move(solver.eigenvectors()),
37+
std::move(solver.eigenvalues()));
38+
}
39+
40+
/**
41+
* Return the eigendecomposition of a (complex-valued) matrix
42+
*
43+
* @tparam EigCplxMat type of complex matrix argument
44+
* @param[in] m matrix to find the eigendecomposition of. Must be square and
45+
* have a non-zero size.
46+
* @return A tuple V,D where V is a matrix where the columns are the
47+
* complex-valued eigenvectors of `m` and D is a complex-valued column vector
48+
* with entries the eigenvectors of `m`
49+
*/
50+
template <typename EigCplxMat,
51+
require_eigen_matrix_dynamic_vt<is_complex, EigCplxMat>* = nullptr>
52+
inline std::tuple<
53+
Eigen::Matrix<complex_return_t<value_type_t<EigCplxMat>>, -1, -1>,
54+
Eigen::Matrix<complex_return_t<value_type_t<EigCplxMat>>, -1, 1>>
55+
eigendecompose(const EigCplxMat& m) {
56+
if (unlikely(m.size() == 0)) {
57+
return std::make_tuple(
58+
Eigen::Matrix<complex_return_t<value_type_t<EigCplxMat>>, -1, -1>(0, 0),
59+
Eigen::Matrix<complex_return_t<value_type_t<EigCplxMat>>, -1, 1>(0, 1));
60+
}
61+
check_square("eigendecompose", "m", m);
62+
63+
using PlainMat = Eigen::Matrix<scalar_type_t<EigCplxMat>, -1, -1>;
64+
const PlainMat& m_eval = m;
65+
66+
Eigen::ComplexEigenSolver<PlainMat> solver(m_eval);
67+
68+
return std::make_tuple(std::move(solver.eigenvectors()),
69+
std::move(solver.eigenvalues()));
70+
}
71+
72+
} // namespace math
73+
} // namespace stan
74+
#endif
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#ifndef STAN_MATH_PRIM_FUN_EIGENDECOMPOSE_SYM_HPP
2+
#define STAN_MATH_PRIM_FUN_EIGENDECOMPOSE_SYM_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/Eigen.hpp>
7+
8+
namespace stan {
9+
namespace math {
10+
11+
/**
12+
* Return the eigendecomposition of the specified symmetric matrix.
13+
*
14+
* @tparam EigMat type of the matrix
15+
* @param m Specified matrix.
16+
* @return A tuple V,D where V is a matrix where the columns are the
17+
* eigenvectors of m, and D is a column vector of the eigenvalues of m.
18+
* The eigenvalues are in ascending order of magnitude, with the eigenvectors
19+
* provided in the same order.
20+
*/
21+
template <typename EigMat, require_eigen_t<EigMat>* = nullptr,
22+
require_not_st_var<EigMat>* = nullptr>
23+
std::tuple<Eigen::Matrix<value_type_t<EigMat>, -1, -1>,
24+
Eigen::Matrix<value_type_t<EigMat>, -1, 1>>
25+
eigendecompose_sym(const EigMat& m) {
26+
if (unlikely(m.size() == 0)) {
27+
return std::make_tuple(Eigen::Matrix<value_type_t<EigMat>, -1, -1>(0, 0),
28+
Eigen::Matrix<value_type_t<EigMat>, -1, 1>(0, 1));
29+
}
30+
using PlainMat = plain_type_t<EigMat>;
31+
const PlainMat& m_eval = m;
32+
check_symmetric("eigendecompose_sym", "m", m_eval);
33+
34+
Eigen::SelfAdjointEigenSolver<PlainMat> solver(m_eval);
35+
return std::make_tuple(std::move(solver.eigenvectors()),
36+
std::move(solver.eigenvalues()));
37+
}
38+
39+
} // namespace math
40+
} // namespace stan
41+
#endif

stan/math/prim/fun/eigenvalues.hpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@ template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr,
1919
require_not_vt_complex<EigMat>* = nullptr>
2020
inline Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, -1, 1> eigenvalues(
2121
const EigMat& m) {
22+
if (unlikely(m.size() == 0)) {
23+
return Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, -1, 1>(0, 1);
24+
}
25+
check_square("eigenvalues", "m", m);
2226
using PlainMat = plain_type_t<EigMat>;
2327
const PlainMat& m_eval = m;
24-
check_nonzero_size("eigenvalues", "m", m_eval);
25-
check_square("eigenvalues", "m", m_eval);
2628

2729
Eigen::EigenSolver<PlainMat> solver(m_eval, false);
2830
return solver.eigenvalues();
@@ -37,14 +39,16 @@ inline Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, -1, 1> eigenvalues(
3739
* @return a complex-valued column vector with entries the eigenvectors of `m`
3840
*/
3941
template <typename EigCplxMat,
40-
require_eigen_matrix_dynamic_t<EigCplxMat>* = nullptr,
41-
require_vt_complex<EigCplxMat>* = nullptr>
42+
require_eigen_matrix_dynamic_vt<is_complex, EigCplxMat>* = nullptr>
4243
inline Eigen::Matrix<complex_return_t<value_type_t<EigCplxMat>>, -1, 1>
4344
eigenvalues(const EigCplxMat& m) {
45+
if (unlikely(m.size() == 0)) {
46+
return Eigen::Matrix<complex_return_t<value_type_t<EigCplxMat>>, -1, 1>(0,
47+
1);
48+
}
49+
check_square("eigenvalues", "m", m);
4450
using PlainMat = Eigen::Matrix<scalar_type_t<EigCplxMat>, -1, -1>;
4551
const PlainMat& m_eval = m;
46-
check_nonzero_size("eigenvalues", "m", m_eval);
47-
check_square("eigenvalues", "m", m_eval);
4852

4953
Eigen::ComplexEigenSolver<PlainMat> solver(m_eval, false);
5054

stan/math/prim/fun/eigenvalues_sym.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,9 @@ namespace math {
1010

1111
/**
1212
* Return the eigenvalues of the specified symmetric matrix
13-
* in descending order of magnitude. This function is more
13+
* in ascending order of magnitude. This function is more
1414
* efficient than the general eigenvalues function for symmetric
1515
* matrices.
16-
* <p>See <code>eigen_decompose()</code> for more information.
1716
*
1817
* @tparam EigMat type of the matrix
1918
* @param m Specified matrix.
@@ -22,9 +21,11 @@ namespace math {
2221
template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr,
2322
require_not_st_var<EigMat>* = nullptr>
2423
Eigen::Matrix<value_type_t<EigMat>, -1, 1> eigenvalues_sym(const EigMat& m) {
24+
if (unlikely(m.size() == 0)) {
25+
return Eigen::Matrix<value_type_t<EigMat>, -1, 1>(0, 1);
26+
}
2527
using PlainMat = plain_type_t<EigMat>;
2628
const PlainMat& m_eval = m;
27-
check_nonzero_size("eigenvalues_sym", "m", m_eval);
2829
check_symmetric("eigenvalues_sym", "m", m_eval);
2930

3031
Eigen::SelfAdjointEigenSolver<PlainMat> solver(m_eval,

stan/math/prim/fun/eigenvectors.hpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@ template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr,
2020
require_not_vt_complex<EigMat>* = nullptr>
2121
inline Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, -1, -1>
2222
eigenvectors(const EigMat& m) {
23+
if (unlikely(m.size() == 0)) {
24+
return Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, -1, -1>(0, 0);
25+
}
26+
check_square("eigenvectors", "m", m);
2327
using PlainMat = plain_type_t<EigMat>;
2428
const PlainMat& m_eval = m;
25-
check_nonzero_size("eigenvectors", "m", m_eval);
26-
check_square("eigenvectors", "m", m_eval);
2729

2830
Eigen::EigenSolver<PlainMat> solver(m_eval);
2931
return solver.eigenvectors();
@@ -39,14 +41,16 @@ eigenvectors(const EigMat& m) {
3941
* `m`
4042
*/
4143
template <typename EigCplxMat,
42-
require_eigen_matrix_dynamic_t<EigCplxMat>* = nullptr,
43-
require_vt_complex<EigCplxMat>* = nullptr>
44+
require_eigen_matrix_dynamic_vt<is_complex, EigCplxMat>* = nullptr>
4445
inline Eigen::Matrix<complex_return_t<value_type_t<EigCplxMat>>, -1, -1>
4546
eigenvectors(const EigCplxMat& m) {
47+
if (unlikely(m.size() == 0)) {
48+
return Eigen::Matrix<complex_return_t<value_type_t<EigCplxMat>>, -1, -1>(0,
49+
0);
50+
}
51+
check_square("eigenvectors", "m", m);
4652
using PlainMat = Eigen::Matrix<scalar_type_t<EigCplxMat>, -1, -1>;
4753
const PlainMat& m_eval = m;
48-
check_nonzero_size("eigenvectors", "m", m_eval);
49-
check_square("eigenvectors", "m", m_eval);
5054

5155
Eigen::ComplexEigenSolver<PlainMat> solver(m_eval);
5256
return solver.eigenvectors();

0 commit comments

Comments
 (0)