Skip to content

Commit bb657f5

Browse files
committed
Update svd
1 parent 7730846 commit bb657f5

1 file changed

Lines changed: 23 additions & 21 deletions

File tree

stan/math/rev/fun/svd.hpp

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -60,30 +60,32 @@ inline auto svd(const EigMat& m) {
6060
arena_Fm, M]() mutable {
6161
// SVD-U reverse mode
6262
Eigen::MatrixXd UUadjT = arena_U.val_op().transpose() * arena_U.adj_op();
63-
arena_m.adj()
64-
+= .5 * arena_U.val_op()
65-
* (arena_Fp.array() * (UUadjT - UUadjT.transpose()).array())
66-
.matrix()
67-
* arena_V.val_op().transpose()
68-
+ (Eigen::MatrixXd::Identity(arena_m.rows(), arena_m.rows())
69-
- arena_U.val_op() * arena_U.val_op().transpose())
70-
* arena_U.adj_op()
71-
* singular_values.val_op().asDiagonal().inverse()
72-
* arena_V.val_op().transpose();
63+
auto u_adj
64+
= .5 * arena_U.val_op()
65+
* (arena_Fp.array() * (UUadjT - UUadjT.transpose()).array())
66+
.matrix()
67+
* arena_V.val_op().transpose()
68+
+ (Eigen::MatrixXd::Identity(arena_m.rows(), arena_m.rows())
69+
- arena_U.val_op() * arena_U.val_op().transpose())
70+
* arena_U.adj_op()
71+
* singular_values.val_op().asDiagonal().inverse()
72+
* arena_V.val_op().transpose();
7373
// Singular values reverse mode
74-
arena_m.adj() += arena_U.val_op() * singular_values.adj().asDiagonal()
75-
* arena_V.val_op().transpose();
74+
auto d_adj = arena_U.val_op() * singular_values.adj().asDiagonal()
75+
* arena_V.val_op().transpose();
7676
// SVD-V reverse mode
7777
Eigen::MatrixXd VTVadj = arena_V.val_op().transpose() * arena_V.adj_op();
78-
arena_m.adj()
79-
+= 0.5 * arena_U.val_op()
80-
* (arena_Fm.array() * (VTVadj - VTVadj.transpose()).array())
81-
.matrix()
82-
* arena_V.val_op().transpose()
83-
+ arena_U.val_op() * singular_values.val_op().asDiagonal().inverse()
84-
* arena_V.adj_op().transpose()
85-
* (Eigen::MatrixXd::Identity(arena_m.cols(), arena_m.cols())
86-
- arena_V.val_op() * arena_V.val_op().transpose());
78+
auto v_adj
79+
= 0.5 * arena_U.val_op()
80+
* (arena_Fm.array() * (VTVadj - VTVadj.transpose()).array())
81+
.matrix()
82+
* arena_V.val_op().transpose()
83+
+ arena_U.val_op() * singular_values.val_op().asDiagonal().inverse()
84+
* arena_V.adj_op().transpose()
85+
* (Eigen::MatrixXd::Identity(arena_m.cols(), arena_m.cols())
86+
- arena_V.val_op() * arena_V.val_op().transpose());
87+
88+
arena_m.adj() += u_adj + d_adj + v_adj;
8789
});
8890

8991
return std::make_tuple(mat_ret_type(arena_U), vec_ret_type(singular_values),

0 commit comments

Comments
 (0)