Skip to content

Commit f55acbd

Browse files
committed
Rev for stoch matrices
1 parent 36c22c1 commit f55acbd

2 files changed

Lines changed: 125 additions & 150 deletions

File tree

stan/math/rev/constraint/stochastic_column_constrain.hpp

Lines changed: 62 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
#ifndef STAN_MATH_REV_CONSTRAINT_STOCHASTIC_COLUMN_CONSTRAIN_HPP
22
#define STAN_MATH_REV_CONSTRAINT_STOCHASTIC_COLUMN_CONSTRAIN_HPP
33

4+
#include <stan/math/prim/fun/Eigen.hpp>
45
#include <stan/math/rev/meta.hpp>
56
#include <stan/math/rev/core/reverse_pass_callback.hpp>
67
#include <stan/math/rev/core/arena_matrix.hpp>
78
#include <stan/math/rev/fun/value_of.hpp>
8-
#include <stan/math/prim/fun/Eigen.hpp>
9-
#include <stan/math/prim/fun/inv_logit.hpp>
10-
#include <stan/math/prim/fun/log1p_exp.hpp>
9+
#include <stan/math/prim/constraint/stochastic_column_constrain.hpp>
10+
#include <stan/math/rev/constraint/sum_to_zero_constrain.hpp>
1111
#include <cmath>
1212
#include <tuple>
1313
#include <vector>
@@ -27,44 +27,36 @@ namespace math {
2727
template <typename T, require_rev_matrix_t<T>* = nullptr>
2828
inline plain_type_t<T> stochastic_column_constrain(const T& y) {
2929
using ret_type = plain_type_t<T>;
30-
const Eigen::Index N = y.rows();
31-
const Eigen::Index M = y.cols();
32-
using eigen_mat_rowmajor
33-
= Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
34-
arena_t<eigen_mat_rowmajor> x_val(N + 1, M);
30+
31+
const auto N = y.rows();
32+
const auto M = y.cols();
33+
arena_t<T> arena_y = y;
34+
35+
arena_t<ret_type> arena_x = stochastic_column_constrain(arena_y.val_op());
36+
3537
if (unlikely(N == 0 || M == 0)) {
36-
return ret_type(x_val);
37-
}
38-
arena_t<change_eigen_options_t<T, Eigen::RowMajor>> arena_y = y;
39-
arena_t<eigen_mat_rowmajor> arena_z(N, M);
40-
using arr_vec = Eigen::Array<double, 1, -1>;
41-
arr_vec stick_len = arr_vec::Constant(M, 1.0);
42-
for (Eigen::Index k = 0; k < N; ++k) {
43-
const double log_N_minus_k = std::log(N - k);
44-
arena_z.row(k)
45-
= inv_logit(arena_y.array().row(k).val_op() - log_N_minus_k).matrix();
46-
x_val.row(k) = stick_len.array() * arena_z.array().row(k);
47-
stick_len -= x_val.array().row(k);
38+
return arena_x;
4839
}
49-
x_val.row(N) = stick_len;
50-
arena_t<ret_type> arena_x = x_val;
51-
reverse_pass_callback([arena_y, arena_x, arena_z]() mutable {
52-
const Eigen::Index N = arena_y.rows();
53-
auto arena_x_arr = arena_x.array();
54-
auto arena_y_arr = arena_y.array();
55-
auto arena_z_arr = arena_z.array();
56-
auto stick_len_val = arena_x.array().row(N).val().eval();
57-
auto stick_len_adj = arena_x.array().row(N).adj().eval();
58-
for (Eigen::Index k = N; k-- > 0;) {
59-
arena_x_arr.row(k).adj() -= stick_len_adj;
60-
stick_len_val += arena_x_arr.row(k).val();
61-
stick_len_adj += arena_x_arr.row(k).adj() * arena_z_arr.row(k);
62-
auto arena_z_adj = arena_x_arr.row(k).adj() * stick_len_val;
63-
arena_y_arr.row(k).adj()
64-
+= arena_z_adj * arena_z_arr.row(k) * (1.0 - arena_z_arr.row(k));
40+
41+
reverse_pass_callback([arena_y, arena_x]() mutable {
42+
const auto M = arena_y.cols();
43+
44+
const auto& x_val = to_ref(arena_x.val_op());
45+
const auto& x_adj = to_ref(arena_x.adj_op());
46+
47+
for (Eigen::Index i = 0; i < M; ++i) {
48+
// backprop for softmax
49+
Eigen::VectorXd x_pre_softmax_adj
50+
= -x_val.col(i) * x_adj.col(i).dot(x_val.col(i))
51+
+ x_val.col(i).cwiseProduct(x_adj.col(i));
52+
53+
// backprop for sum_to_zero_constrain
54+
internal::sum_to_zero_vector_backprop(arena_y.col(i).adj(),
55+
x_pre_softmax_adj);
6556
}
6657
});
67-
return ret_type(arena_x);
58+
59+
return arena_x;
6860
}
6961

7062
/**
@@ -84,51 +76,43 @@ template <typename T, require_rev_matrix_t<T>* = nullptr>
8476
inline plain_type_t<T> stochastic_column_constrain(const T& y,
8577
scalar_type_t<T>& lp) {
8678
using ret_type = plain_type_t<T>;
87-
const Eigen::Index N = y.rows();
88-
const Eigen::Index M = y.cols();
89-
using eigen_mat_rowmajor
90-
= Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
91-
arena_t<eigen_mat_rowmajor> x_val(N + 1, M);
79+
80+
const auto N = y.rows();
81+
const auto M = y.cols();
82+
arena_t<T> arena_y = y;
83+
84+
double lp_val = 0;
85+
arena_t<ret_type> arena_x
86+
= stochastic_column_constrain(arena_y.val_op(), lp_val);
87+
lp += lp_val;
88+
9289
if (unlikely(N == 0 || M == 0)) {
93-
return ret_type(x_val);
90+
return arena_x;
9491
}
95-
arena_t<change_eigen_options_t<T, Eigen::RowMajor>> arena_y = y;
96-
arena_t<eigen_mat_rowmajor> arena_z(N, M);
97-
using arr_vec = Eigen::Array<double, 1, -1>;
98-
arr_vec stick_len = arr_vec::Constant(M, 1.0);
99-
arr_vec adj_y_k(N);
100-
for (Eigen::Index k = 0; k < N; ++k) {
101-
double log_N_minus_k = std::log(N - k);
102-
adj_y_k = arena_y.array().row(k).val() - log_N_minus_k;
103-
arena_z.array().row(k) = inv_logit(adj_y_k);
104-
x_val.array().row(k) = stick_len * arena_z.array().row(k);
105-
lp += sum(log(stick_len)) - sum(log1p_exp(-adj_y_k))
106-
- sum(log1p_exp(adj_y_k));
107-
stick_len -= x_val.array().row(k);
108-
}
109-
x_val.array().row(N) = stick_len;
110-
arena_t<ret_type> arena_x = x_val;
111-
reverse_pass_callback([arena_y, arena_x, arena_z, lp]() mutable {
112-
const Eigen::Index N = arena_y.rows();
113-
auto arena_x_arr = arena_x.array();
114-
auto arena_y_arr = arena_y.array();
115-
auto arena_z_arr = arena_z.array();
116-
auto stick_len_val = arena_x.array().row(N).val().eval();
117-
auto stick_len_adj = arena_x.array().row(N).adj().eval();
118-
for (Eigen::Index k = N; k-- > 0;) {
119-
const double log_N_minus_k = std::log(N - k);
120-
arena_x_arr.row(k).adj() -= stick_len_adj;
121-
stick_len_val += arena_x_arr.row(k).val();
122-
stick_len_adj += lp.adj() / stick_len_val
123-
+ arena_x_arr.row(k).adj() * arena_z_arr.row(k);
124-
auto adj_y_k = arena_y_arr.row(k).val() - log_N_minus_k;
125-
auto arena_z_adj = arena_x_arr.row(k).adj() * stick_len_val;
126-
arena_y_arr.row(k).adj()
127-
+= -(lp.adj() * inv_logit(adj_y_k)) + lp.adj() * inv_logit(-adj_y_k)
128-
+ arena_z_adj * arena_z_arr.row(k) * (1.0 - arena_z_arr.row(k));
92+
93+
reverse_pass_callback([arena_y, arena_x, lp]() mutable {
94+
const auto M = arena_y.cols();
95+
96+
const auto& x_val = to_ref(arena_x.val_op());
97+
98+
// backprop for log jacobian contribution to log density
99+
arena_x.adj().array() += lp.adj() / x_val.array();
100+
101+
const auto& x_adj = to_ref(arena_x.adj_op());
102+
103+
for (Eigen::Index i = 0; i < M; ++i) {
104+
// backprop for softmax
105+
Eigen::VectorXd x_pre_softmax_adj
106+
= -x_val.col(i) * x_adj.col(i).dot(x_val.col(i))
107+
+ x_val.col(i).cwiseProduct(x_adj.col(i));
108+
109+
// backprop for sum_to_zero_constrain
110+
internal::sum_to_zero_vector_backprop(arena_y.col(i).adj(),
111+
x_pre_softmax_adj);
129112
}
130113
});
131-
return ret_type(arena_x);
114+
115+
return arena_x;
132116
}
133117

134118
} // namespace math

stan/math/rev/constraint/stochastic_row_constrain.hpp

Lines changed: 63 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
#ifndef STAN_MATH_REV_CONSTRAINT_STOCHASTIC_ROW_CONSTRAIN_HPP
22
#define STAN_MATH_REV_CONSTRAINT_STOCHASTIC_ROW_CONSTRAIN_HPP
33

4+
#include <stan/math/prim/fun/Eigen.hpp>
45
#include <stan/math/rev/meta.hpp>
56
#include <stan/math/rev/core/reverse_pass_callback.hpp>
67
#include <stan/math/rev/core/arena_matrix.hpp>
78
#include <stan/math/rev/fun/value_of.hpp>
8-
#include <stan/math/prim/fun/Eigen.hpp>
9-
#include <stan/math/prim/fun/inv_logit.hpp>
10-
#include <stan/math/prim/fun/log1p_exp.hpp>
9+
#include <stan/math/prim/constraint/stochastic_row_constrain.hpp>
10+
#include <stan/math/rev/constraint/sum_to_zero_constrain.hpp>
1111
#include <cmath>
1212
#include <tuple>
1313
#include <vector>
@@ -23,43 +23,38 @@ namespace math {
2323
* @return Matrix of dimensionality (N, K) with a simplex along each row
2424
*/
2525
template <typename T, require_rev_matrix_t<T>* = nullptr>
26-
inline plain_type_t<T> stochastic_row_constrain(const T& y) {
26+
inline auto stochastic_row_constrain(const T& y) {
2727
using ret_type = plain_type_t<T>;
28-
const Eigen::Index N = y.rows();
29-
const Eigen::Index M = y.cols();
30-
arena_t<Eigen::MatrixXd> x_val(N, M + 1);
31-
if (unlikely(N == 0 || M == 0)) {
32-
return ret_type(x_val);
33-
}
28+
29+
const auto N = y.rows();
30+
const auto M = y.cols();
3431
arena_t<T> arena_y = y;
35-
arena_t<Eigen::MatrixXd> arena_z(N, M);
36-
Eigen::Array<double, -1, 1> stick_len = Eigen::Array<double, -1, 1>::Ones(N);
37-
for (Eigen::Index j = 0; j < M; ++j) {
38-
double log_N_minus_k = std::log(M - j);
39-
arena_z.col(j).array()
40-
= inv_logit((arena_y.col(j).val_op().array() - log_N_minus_k).matrix());
41-
x_val.col(j).array() = stick_len * arena_z.col(j).array();
42-
stick_len -= x_val.col(j).array();
32+
33+
arena_t<ret_type> arena_x = stochastic_row_constrain(arena_y.val_op());
34+
35+
if (unlikely(N == 0 || M == 0)) {
36+
return arena_x;
4337
}
44-
x_val.col(M).array() = stick_len;
45-
arena_t<ret_type> arena_x = x_val;
46-
reverse_pass_callback([arena_y, arena_x, arena_z]() mutable {
47-
const Eigen::Index M = arena_y.cols();
48-
auto arena_y_arr = arena_y.array();
49-
auto arena_x_arr = arena_x.array();
50-
auto arena_z_arr = arena_z.array();
51-
auto stick_len_val_arr = arena_x_arr.col(M).val_op().eval();
52-
auto stick_len_adj_arr = arena_x_arr.col(M).adj_op().eval();
53-
for (Eigen::Index k = M; k-- > 0;) {
54-
arena_x_arr.col(k).adj() -= stick_len_adj_arr;
55-
stick_len_val_arr += arena_x_arr.col(k).val_op();
56-
stick_len_adj_arr += arena_x_arr.col(k).adj_op() * arena_z_arr.col(k);
57-
arena_y_arr.col(k).adj() += arena_x_arr.adj_op().col(k)
58-
* stick_len_val_arr * arena_z_arr.col(k)
59-
* (1.0 - arena_z_arr.col(k));
38+
39+
reverse_pass_callback([arena_y, arena_x]() mutable {
40+
const auto N = arena_y.rows();
41+
42+
const auto& x_val = to_ref(arena_x.val_op());
43+
const auto& x_adj = to_ref(arena_x.adj_op());
44+
45+
for (Eigen::Index i = 0; i < N; ++i) {
46+
// backprop for softmax
47+
Eigen::VectorXd x_pre_softmax_adj
48+
= -x_val.row(i) * x_adj.row(i).dot(x_val.row(i))
49+
+ x_val.row(i).cwiseProduct(x_adj.row(i));
50+
51+
// backprop for sum_to_zero_constrain
52+
internal::sum_to_zero_vector_backprop(arena_y.row(i).adj(),
53+
x_pre_softmax_adj);
6054
}
6155
});
62-
return ret_type(arena_x);
56+
57+
return arena_x;
6358
}
6459

6560
/**
@@ -79,47 +74,43 @@ template <typename T, require_rev_matrix_t<T>* = nullptr>
7974
inline plain_type_t<T> stochastic_row_constrain(const T& y,
8075
scalar_type_t<T>& lp) {
8176
using ret_type = plain_type_t<T>;
82-
const Eigen::Index N = y.rows();
83-
const Eigen::Index M = y.cols();
84-
arena_t<Eigen::MatrixXd> x_val(N, M + 1);
85-
if (unlikely(N == 0 || M == 0)) {
86-
return ret_type(x_val);
87-
}
77+
78+
const auto N = y.rows();
79+
const auto M = y.cols();
8880
arena_t<T> arena_y = y;
89-
arena_t<Eigen::MatrixXd> arena_z(N, M);
90-
Eigen::Array<double, -1, 1> stick_len = Eigen::Array<double, -1, 1>::Ones(N);
91-
for (Eigen::Index j = 0; j < M; ++j) {
92-
double log_N_minus_k = std::log(M - j);
93-
auto adj_y_k = arena_y.col(j).val_op().array() - log_N_minus_k;
94-
arena_z.col(j).array() = inv_logit(adj_y_k);
95-
x_val.col(j).array() = stick_len * arena_z.col(j).array();
96-
lp += sum(log(stick_len)) - sum(log1p_exp(-adj_y_k))
97-
- sum(log1p_exp(adj_y_k));
98-
stick_len -= x_val.col(j).array();
81+
82+
double lp_val = 0;
83+
arena_t<ret_type> arena_x
84+
= stochastic_row_constrain(arena_y.val_op(), lp_val);
85+
lp += lp_val;
86+
87+
if (unlikely(N == 0 || M == 0)) {
88+
return arena_x;
9989
}
100-
x_val.col(M).array() = stick_len;
101-
arena_t<ret_type> arena_x = x_val;
102-
reverse_pass_callback([arena_y, arena_x, arena_z, lp]() mutable {
103-
const Eigen::Index M = arena_y.cols();
104-
auto arena_y_arr = arena_y.array();
105-
auto arena_x_arr = arena_x.array();
106-
auto arena_z_arr = arena_z.array();
107-
auto stick_len_val = arena_x_arr.col(M).val_op().eval();
108-
auto stick_len_adj = arena_x_arr.col(M).adj_op().eval();
109-
for (Eigen::Index k = M; k-- > 0;) {
110-
const double log_N_minus_k = std::log(M - k);
111-
arena_x_arr.col(k).adj() -= stick_len_adj;
112-
stick_len_val += arena_x_arr.col(k).val_op();
113-
stick_len_adj += lp.adj() / stick_len_val
114-
+ arena_x_arr.adj_op().col(k) * arena_z_arr.col(k);
115-
auto adj_y_k = arena_y_arr.col(k).val_op() - log_N_minus_k;
116-
arena_y_arr.col(k).adj()
117-
+= -(lp.adj() * inv_logit(adj_y_k)) + lp.adj() * inv_logit(-adj_y_k)
118-
+ arena_x_arr.col(k).adj_op() * stick_len_val * arena_z_arr.col(k)
119-
* (1.0 - arena_z_arr.col(k));
90+
91+
reverse_pass_callback([arena_y, arena_x, lp]() mutable {
92+
const auto N = arena_y.rows();
93+
94+
const auto& x_val = to_ref(arena_x.val_op());
95+
96+
// backprop for log jacobian contribution to log density
97+
arena_x.adj().array() += lp.adj() / x_val.array();
98+
99+
const auto& x_adj = to_ref(arena_x.adj_op());
100+
101+
for (Eigen::Index i = 0; i < N; ++i) {
102+
// backprop for softmax
103+
Eigen::VectorXd x_pre_softmax_adj
104+
= -x_val.row(i) * x_adj.row(i).dot(x_val.row(i))
105+
+ x_val.row(i).cwiseProduct(x_adj.row(i));
106+
107+
// backprop for sum_to_zero_constrain
108+
internal::sum_to_zero_vector_backprop(arena_y.row(i).adj(),
109+
x_pre_softmax_adj);
120110
}
121111
});
122-
return ret_type(arena_x);
112+
113+
return arena_x;
123114
}
124115

125116
} // namespace math

0 commit comments

Comments
 (0)