Skip to content

Commit 37e6003

Browse files
committed
Changes per review comments
1 parent f55acbd commit 37e6003

4 files changed

Lines changed: 22 additions & 19 deletions

File tree

stan/math/prim/constraint/simplex_constrain.hpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,7 @@ inline plain_type_t<Vec> simplex_constrain(const Vec& y) {
6565
max_val = fmax(max_val_old, z.coeff(0));
6666
d = d * exp(max_val_old - max_val) + exp(z.coeff(0) - max_val);
6767

68-
for (int i = 0; i <= N; ++i) {
69-
z.coeffRef(i) = exp(z.coeff(i) - max_val) / d;
70-
}
68+
z.array() = (z.array() - max_val).exp() / d;
7169

7270
return z;
7371
}
@@ -128,9 +126,7 @@ inline plain_type_t<Vec> simplex_constrain(const Vec& y, Lp& lp) {
128126
max_val = fmax(max_val_old, z.coeff(0));
129127
d = d * exp(max_val_old - max_val) + exp(z.coeff(0) - max_val);
130128

131-
for (int i = 0; i <= N; ++i) {
132-
z.coeffRef(i) = exp(z.coeff(i) - max_val) / d;
133-
}
129+
z.array() = (z.array() - max_val).exp() / d;
134130

135131
// equivalent to z.log().sum() + 0.5 * log(N + 1)
136132
lp += -(N + 1) * (max_val + log(d)) + 0.5 * log(N + 1);

stan/math/rev/constraint/stochastic_column_constrain.hpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,13 @@ inline plain_type_t<T> stochastic_column_constrain(const T& y) {
4141
reverse_pass_callback([arena_y, arena_x]() mutable {
4242
const auto M = arena_y.cols();
4343

44-
const auto& x_val = to_ref(arena_x.val_op());
45-
const auto& x_adj = to_ref(arena_x.adj_op());
44+
auto&& x_val = arena_x.val_op();
45+
auto&& x_adj = arena_x.adj_op();
4646

47+
Eigen::VectorXd x_pre_softmax_adj(x_val.rows());
4748
for (Eigen::Index i = 0; i < M; ++i) {
4849
// backprop for softmax
49-
Eigen::VectorXd x_pre_softmax_adj
50+
x_pre_softmax_adj.noalias()
5051
= -x_val.col(i) * x_adj.col(i).dot(x_val.col(i))
5152
+ x_val.col(i).cwiseProduct(x_adj.col(i));
5253

@@ -93,16 +94,17 @@ inline plain_type_t<T> stochastic_column_constrain(const T& y,
9394
reverse_pass_callback([arena_y, arena_x, lp]() mutable {
9495
const auto M = arena_y.cols();
9596

96-
const auto& x_val = to_ref(arena_x.val_op());
97+
auto&& x_val = arena_x.val_op();
9798

9899
// backprop for log jacobian contribution to log density
99100
arena_x.adj().array() += lp.adj() / x_val.array();
100101

101-
const auto& x_adj = to_ref(arena_x.adj_op());
102+
auto&& x_adj = arena_x.adj_op();
102103

104+
Eigen::VectorXd x_pre_softmax_adj(x_val.rows());
103105
for (Eigen::Index i = 0; i < M; ++i) {
104106
// backprop for softmax
105-
Eigen::VectorXd x_pre_softmax_adj
107+
x_pre_softmax_adj.noalias()
106108
= -x_val.col(i) * x_adj.col(i).dot(x_val.col(i))
107109
+ x_val.col(i).cwiseProduct(x_adj.col(i));
108110

stan/math/rev/constraint/stochastic_row_constrain.hpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,13 @@ inline auto stochastic_row_constrain(const T& y) {
3939
reverse_pass_callback([arena_y, arena_x]() mutable {
4040
const auto N = arena_y.rows();
4141

42-
const auto& x_val = to_ref(arena_x.val_op());
43-
const auto& x_adj = to_ref(arena_x.adj_op());
42+
auto&& x_val = arena_x.val_op();
43+
auto&& x_adj = arena_x.adj_op();
4444

45+
Eigen::VectorXd x_pre_softmax_adj(x_val.cols());
4546
for (Eigen::Index i = 0; i < N; ++i) {
4647
// backprop for softmax
47-
Eigen::VectorXd x_pre_softmax_adj
48+
x_pre_softmax_adj.noalias()
4849
= -x_val.row(i) * x_adj.row(i).dot(x_val.row(i))
4950
+ x_val.row(i).cwiseProduct(x_adj.row(i));
5051

@@ -91,16 +92,16 @@ inline plain_type_t<T> stochastic_row_constrain(const T& y,
9192
reverse_pass_callback([arena_y, arena_x, lp]() mutable {
9293
const auto N = arena_y.rows();
9394

94-
const auto& x_val = to_ref(arena_x.val_op());
95-
95+
auto&& x_val = arena_x.val_op();
9696
// backprop for log jacobian contribution to log density
9797
arena_x.adj().array() += lp.adj() / x_val.array();
9898

99-
const auto& x_adj = to_ref(arena_x.adj_op());
99+
auto&& x_adj = arena_x.adj_op();
100100

101+
Eigen::VectorXd x_pre_softmax_adj(x_val.cols());
101102
for (Eigen::Index i = 0; i < N; ++i) {
102103
// backprop for softmax
103-
Eigen::VectorXd x_pre_softmax_adj
104+
x_pre_softmax_adj.noalias()
104105
= -x_val.row(i) * x_adj.row(i).dot(x_val.row(i))
105106
+ x_val.row(i).cwiseProduct(x_adj.row(i));
106107

stan/math/rev/constraint/sum_to_zero_constrain.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ namespace internal {
2121
* The reverse pass backprop for the sum_to_zero_constrain on
2222
* vectors. This is separated out so it can also be called by
2323
* simplex_constrain.
24+
*
25+
* @tparam T type of the adjoint vector
26+
* @param y_adj The adjoint of the free vector (size N)
27+
* @param z_adj The adjoint of the zero-sum vector (size N + 1)
2428
*/
2529
template <typename T>
2630
void sum_to_zero_vector_backprop(T&& y_adj, const Eigen::VectorXd& z_adj) {

0 commit comments

Comments
 (0)