@@ -60,30 +60,32 @@ inline auto svd(const EigMat& m) {
6060 arena_Fm, M]() mutable {
6161 // SVD-U reverse mode
6262 Eigen::MatrixXd UUadjT = arena_U.val_op ().transpose () * arena_U.adj_op ();
63- arena_m. adj ()
64- + = .5 * arena_U.val_op ()
65- * (arena_Fp.array () * (UUadjT - UUadjT.transpose ()).array ())
66- .matrix ()
67- * arena_V.val_op ().transpose ()
68- + (Eigen::MatrixXd::Identity (arena_m.rows (), arena_m.rows ())
69- - arena_U.val_op () * arena_U.val_op ().transpose ())
70- * arena_U.adj_op ()
71- * singular_values.val_op ().asDiagonal ().inverse ()
72- * arena_V.val_op ().transpose ();
63+ auto u_adj
64+ = .5 * arena_U.val_op ()
65+ * (arena_Fp.array () * (UUadjT - UUadjT.transpose ()).array ())
66+ .matrix ()
67+ * arena_V.val_op ().transpose ()
68+ + (Eigen::MatrixXd::Identity (arena_m.rows (), arena_m.rows ())
69+ - arena_U.val_op () * arena_U.val_op ().transpose ())
70+ * arena_U.adj_op ()
71+ * singular_values.val_op ().asDiagonal ().inverse ()
72+ * arena_V.val_op ().transpose ();
7373 // Singular values reverse mode
74- arena_m. adj () + = arena_U.val_op () * singular_values.adj ().asDiagonal ()
75- * arena_V.val_op ().transpose ();
74+ auto d_adj = arena_U.val_op () * singular_values.adj ().asDiagonal ()
75+ * arena_V.val_op ().transpose ();
7676 // SVD-V reverse mode
7777 Eigen::MatrixXd VTVadj = arena_V.val_op ().transpose () * arena_V.adj_op ();
78- arena_m.adj ()
79- += 0.5 * arena_U.val_op ()
80- * (arena_Fm.array () * (VTVadj - VTVadj.transpose ()).array ())
81- .matrix ()
82- * arena_V.val_op ().transpose ()
83- + arena_U.val_op () * singular_values.val_op ().asDiagonal ().inverse ()
84- * arena_V.adj_op ().transpose ()
85- * (Eigen::MatrixXd::Identity (arena_m.cols (), arena_m.cols ())
86- - arena_V.val_op () * arena_V.val_op ().transpose ());
78+ auto v_adj
79+ = 0.5 * arena_U.val_op ()
80+ * (arena_Fm.array () * (VTVadj - VTVadj.transpose ()).array ())
81+ .matrix ()
82+ * arena_V.val_op ().transpose ()
83+ + arena_U.val_op () * singular_values.val_op ().asDiagonal ().inverse ()
84+ * arena_V.adj_op ().transpose ()
85+ * (Eigen::MatrixXd::Identity (arena_m.cols (), arena_m.cols ())
86+ - arena_V.val_op () * arena_V.val_op ().transpose ());
87+
88+ arena_m.adj () += u_adj + d_adj + v_adj;
8789 });
8890
8991 return std::make_tuple (mat_ret_type (arena_U), vec_ret_type (singular_values),
0 commit comments