Skip to content

Commit edfc5b8

Browse files
authored
Merge pull request #2846 from stan-dev/feature/2733-complex-eigendecomposition
Add overloads to eigendecomposition for complex matrices
2 parents 133fd01 + bc3a33e commit edfc5b8

4 files changed

Lines changed: 129 additions & 14 deletions

File tree

stan/math/prim/fun/eigenvalues.hpp

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,47 @@
77
namespace stan {
88
namespace math {
99

10-
template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr>
11-
inline auto eigenvalues(const EigMat& m) {
10+
/**
11+
* Return the eigenvalues of a (real-valued) matrix
12+
*
13+
* @tparam EigMat type of real matrix argument
14+
* @param[in] m matrix to find the eigenvectors of. Must be square and have a
15+
* non-zero size.
16+
* @return a complex-valued column vector with entries the eigenvectors of `m`
17+
*/
18+
template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr,
19+
require_not_vt_complex<EigMat>* = nullptr>
20+
inline Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, 1, -1> eigenvalues(
21+
const EigMat& m) {
1222
using PlainMat = plain_type_t<EigMat>;
1323
const PlainMat& m_eval = m;
1424
check_nonzero_size("eigenvalues", "m", m_eval);
1525
check_square("eigenvalues", "m", m_eval);
1626

17-
Eigen::EigenSolver<PlainMat> solver(m_eval);
27+
Eigen::EigenSolver<PlainMat> solver(m_eval, false);
28+
return solver.eigenvalues();
29+
}
30+
31+
/**
32+
* Return the eigenvalues of a (complex-valued) matrix
33+
*
34+
* @tparam EigCplxMat type of complex matrix argument
35+
* @param[in] m matrix to find the eigenvectors of. Must be square and have a
36+
* non-zero size.
37+
* @return a complex-valued column vector with entries the eigenvectors of `m`
38+
*/
39+
template <typename EigCplxMat,
40+
require_eigen_matrix_dynamic_t<EigCplxMat>* = nullptr,
41+
require_vt_complex<EigCplxMat>* = nullptr>
42+
inline Eigen::Matrix<complex_return_t<value_type_t<EigCplxMat>>, 1, -1>
43+
eigenvalues(const EigCplxMat& m) {
44+
using PlainMat = Eigen::Matrix<scalar_type_t<EigCplxMat>, -1, -1>;
45+
const PlainMat& m_eval = m;
46+
check_nonzero_size("eigenvalues", "m", m_eval);
47+
check_square("eigenvalues", "m", m_eval);
48+
49+
Eigen::ComplexEigenSolver<PlainMat> solver(m_eval, false);
50+
1851
return solver.eigenvalues();
1952
}
2053

stan/math/prim/fun/eigenvectors.hpp

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,19 @@
77
namespace stan {
88
namespace math {
99

10-
template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr>
11-
inline auto eigenvectors(const EigMat& m) {
10+
/**
11+
* Return the eigenvectors of a (real-valued) matrix
12+
*
13+
* @tparam EigMat type of real matrix argument
14+
* @param[in] m matrix to find the eigenvectors of. Must be square and have a
15+
* non-zero size.
16+
* @return a complex-valued matrix with columns representing the eigenvectors of
17+
* `m`
18+
*/
19+
template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr,
20+
require_not_vt_complex<EigMat>* = nullptr>
21+
inline Eigen::Matrix<complex_return_t<value_type_t<EigMat>>, -1, -1>
22+
eigenvectors(const EigMat& m) {
1223
using PlainMat = plain_type_t<EigMat>;
1324
const PlainMat& m_eval = m;
1425
check_nonzero_size("eigenvectors", "m", m_eval);
@@ -18,6 +29,29 @@ inline auto eigenvectors(const EigMat& m) {
1829
return solver.eigenvectors();
1930
}
2031

32+
/**
33+
* Return the eigenvectors of a (complex-valued) matrix
34+
*
35+
* @tparam EigCplxMat type of complex matrix argument
36+
* @param[in] m matrix to find the eigenvectors of. Must be square and have a
37+
* non-zero size.
38+
* @return a complex-valued matrix with columns representing the eigenvectors of
39+
* `m`
40+
*/
41+
template <typename EigCplxMat,
42+
require_eigen_matrix_dynamic_t<EigCplxMat>* = nullptr,
43+
require_vt_complex<EigCplxMat>* = nullptr>
44+
inline Eigen::Matrix<complex_return_t<value_type_t<EigCplxMat>>, -1, -1>
45+
eigenvectors(const EigCplxMat& m) {
46+
using PlainMat = Eigen::Matrix<scalar_type_t<EigCplxMat>, -1, -1>;
47+
const PlainMat& m_eval = m;
48+
check_nonzero_size("eigenvectors", "m", m_eval);
49+
check_square("eigenvectors", "m", m_eval);
50+
51+
Eigen::ComplexEigenSolver<PlainMat> solver(m_eval);
52+
return solver.eigenvectors();
53+
}
54+
2155
} // namespace math
2256
} // namespace stan
2357
#endif

test/unit/math/mix/fun/eigenvalues_test.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,19 @@ TEST(mathMixFun, eigenvalues) {
1515
EXPECT_THROW(f(a32), std::invalid_argument);
1616
}
1717

18+
TEST(mathMixFun, eigenvaluesComplex) {
19+
auto f = [](const auto& x) {
20+
using stan::math::eigenvalues;
21+
return eigenvalues(stan::math::to_complex(x, 0));
22+
};
23+
for (const auto& x : stan::test::square_test_matrices(0, 2)) {
24+
stan::test::expect_ad(f, x);
25+
}
26+
27+
Eigen::MatrixXd a32(3, 2);
28+
a32 << 3, -5, 7, -7.2, 9.1, -6.3;
29+
EXPECT_THROW(f(a32), std::invalid_argument);
30+
}
31+
1832
// see eigenvectors_test.cpp for test of eigenvectors() and eigenvalues()
1933
// using reconstruction identities

test/unit/math/mix/fun/eigenvectors_test.cpp

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,33 @@ TEST(mathMixFun, eigenvectors) {
1515
EXPECT_THROW(f(a32), std::invalid_argument);
1616
}
1717

18+
TEST(mathMixFun, eigenvectorsComplex) {
19+
auto f = [](const auto& x) {
20+
using stan::math::eigenvectors;
21+
return eigenvectors(stan::math::to_complex(x, 0));
22+
};
23+
for (const auto& x : stan::test::square_test_matrices(0, 2)) {
24+
stan::test::expect_ad(f, x);
25+
}
26+
27+
Eigen::MatrixXd a32(3, 2);
28+
a32 << 3, -5, 7, -7.2, 9.1, -6.3;
29+
EXPECT_THROW(f(a32), std::invalid_argument);
30+
}
31+
1832
template <typename T>
1933
void expect_identity_matrix(const T& x) {
20-
EXPECT_EQUAL(x.rows(), x.cols());
34+
EXPECT_EQ(x.rows(), x.cols());
2135
for (int j = 0; j < x.cols(); ++j) {
2236
for (int i = 0; i < x.rows(); ++i) {
23-
EXPECT_NEAR(i == j ? 1 : 0, x(i, j), 1e-6);
37+
EXPECT_NEAR(i == j ? 1 : 0, stan::math::value_of_rec(x(i, j)), 1e-6);
2438
}
2539
}
2640
}
2741

2842
template <typename T>
2943
void expectEigenvectorsId() {
30-
for (const auto& m_d : stan::test::square_test_matrices(0, 2)) {
44+
for (const auto& m_d : stan::test::square_test_matrices(1, 2)) {
3145
Eigen::Matrix<T, -1, -1> m(m_d);
3246
auto vecs = eigenvectors(m).eval();
3347
auto vals = eigenvalues(m).eval();
@@ -36,7 +50,21 @@ void expectEigenvectorsId() {
3650
}
3751
}
3852

39-
// THESE TESTS USED TO WORK STANDALONE
53+
template <typename T>
54+
void expectComplexEigenvectorsId() {
55+
Eigen::Matrix<std::complex<T>, -1, -1> c22(2, 2);
56+
c22 << stan::math::to_complex(T(0), T(-1)),
57+
stan::math::to_complex(T(0), T(0)), stan::math::to_complex(T(2), T(0)),
58+
stan::math::to_complex(T(4), T(0));
59+
auto eigenvalues = stan::math::eigenvalues(c22);
60+
auto eigenvectors = stan::math::eigenvectors(c22);
61+
62+
auto I = (eigenvectors.inverse() * c22 * eigenvectors
63+
* eigenvalues.asDiagonal().inverse())
64+
.real();
65+
66+
expect_identity_matrix(I);
67+
}
4068

4169
TEST(mathMixFun, eigenvectorsId) {
4270
using d_t = double;
@@ -46,9 +74,15 @@ TEST(mathMixFun, eigenvectorsId) {
4674
using fv_t = stan::math::fvar<stan::math::var>;
4775
using ffv_t = stan::math::fvar<fv_t>;
4876

49-
// expectEigenvectorsId<v_t>();
50-
// expectEigenvectorsId<fd_t>();
51-
// expectEigenvectorsId<ffd_t>();
52-
// expectEigenvectorsId<fv_t>();
53-
// expectEigenvectorsId<ffv_t>();
77+
expectEigenvectorsId<v_t>();
78+
expectEigenvectorsId<fd_t>();
79+
expectEigenvectorsId<ffd_t>();
80+
expectEigenvectorsId<fv_t>();
81+
expectEigenvectorsId<ffv_t>();
82+
83+
expectComplexEigenvectorsId<v_t>();
84+
expectComplexEigenvectorsId<fd_t>();
85+
expectComplexEigenvectorsId<ffd_t>();
86+
expectComplexEigenvectorsId<fv_t>();
87+
expectComplexEigenvectorsId<ffv_t>();
5488
}

0 commit comments

Comments
 (0)