Skip to content

Commit 15bed18

Browse files
committed
Clean up simplex rev
1 parent e8f3058 commit 15bed18

2 files changed

Lines changed: 10 additions & 12 deletions

File tree

stan/math/rev/constraint/simplex_constrain.hpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,12 @@ inline auto simplex_constrain(const T& y) {
4444
reverse_pass_callback([arena_y, arena_x]() mutable {
4545
const auto& res_val = to_ref(arena_x.val());
4646

47-
Eigen::VectorXd x_pre_softmax_adj = Eigen::VectorXd::Zero(res_val.size());
4847
// backprop for softmax
49-
x_pre_softmax_adj += -res_val * arena_x.adj().dot(res_val)
50-
+ res_val.cwiseProduct(arena_x.adj());
48+
Eigen::VectorXd x_pre_softmax_adj = -res_val * arena_x.adj().dot(res_val)
49+
+ res_val.cwiseProduct(arena_x.adj());
5150

5251
// backprop for sum_to_zero_constrain
53-
internal::sum_to_zero_vector_backprop(arena_y, x_pre_softmax_adj);
52+
internal::sum_to_zero_vector_backprop(arena_y.adj(), x_pre_softmax_adj);
5453
});
5554

5655
return ret_type(arena_x);
@@ -91,13 +90,12 @@ inline auto simplex_constrain(const T& y, scalar_type_t<T>& lp) {
9190
// backprop for log jacobian contribution to log density
9291
arena_x.adj().array() += lp.adj() / res_val.array();
9392

94-
Eigen::VectorXd x_pre_softmax_adj = Eigen::VectorXd::Zero(res_val.size());
9593
// backprop for softmax
96-
x_pre_softmax_adj += -res_val * arena_x.adj().dot(res_val)
97-
+ res_val.cwiseProduct(arena_x.adj());
94+
Eigen::VectorXd x_pre_softmax_adj = -res_val * arena_x.adj().dot(res_val)
95+
+ res_val.cwiseProduct(arena_x.adj());
9896

9997
// backprop for sum_to_zero_constrain
100-
internal::sum_to_zero_vector_backprop(arena_y, x_pre_softmax_adj);
98+
internal::sum_to_zero_vector_backprop(arena_y.adj(), x_pre_softmax_adj);
10199
});
102100

103101
return ret_type(arena_x);

stan/math/rev/constraint/sum_to_zero_constrain.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ namespace internal {
2323
* simplex_constrain.
2424
*/
2525
template <typename T>
26-
void sum_to_zero_vector_backprop(T&& arena_y, Eigen::VectorXd z_adj) {
27-
const auto N = arena_y.size();
26+
void sum_to_zero_vector_backprop(T&& y_adj, const Eigen::VectorXd& z_adj) {
27+
const auto N = y_adj.size();
2828

2929
double sum_u_adj = 0;
3030
for (int i = 0; i < N; ++i) {
@@ -38,7 +38,7 @@ void sum_to_zero_vector_backprop(T&& arena_y, Eigen::VectorXd z_adj) {
3838

3939
double w_adj = v_adj + sum_u_adj;
4040

41-
arena_y.adj().coeffRef(i) += w_adj / sqrt(n * (n + 1));
41+
y_adj.coeffRef(i) += w_adj / sqrt(n * (n + 1));
4242
}
4343
}
4444

@@ -78,7 +78,7 @@ inline auto sum_to_zero_constrain(T&& y) {
7878
arena_t<ret_type> arena_z = sum_to_zero_constrain(arena_y.val());
7979

8080
reverse_pass_callback([arena_y, arena_z]() mutable {
81-
internal::sum_to_zero_vector_backprop(arena_y, arena_z.adj());
81+
internal::sum_to_zero_vector_backprop(arena_y.adj(), arena_z.adj());
8282
});
8383

8484
return arena_z;

0 commit comments

Comments
 (0)