Skip to content

Commit dde6950

Browse files
committed
Write out derivatives
1 parent 465a209 commit dde6950

2 files changed

Lines changed: 81 additions & 26 deletions

File tree

stan/math/rev/constraint/simplex_constrain.hpp

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55
#include <stan/math/rev/meta.hpp>
66
#include <stan/math/rev/core/reverse_pass_callback.hpp>
77
#include <stan/math/rev/core/arena_matrix.hpp>
8-
#include <stan/math/rev/fun/value_of.hpp>
9-
#include <stan/math/prim/constraint/sum_to_zero_constrain.hpp>
10-
#include <stan/math/prim/fun/softmax.hpp>
11-
#include <stan/math/prim/fun/log_softmax.hpp>
8+
#include <stan/math/rev/constraint/sum_to_zero_constrain.hpp>
9+
#include <stan/math/prim/constraint/simplex_constrain.hpp>
1210
#include <cmath>
1311
#include <tuple>
1412
#include <vector>
@@ -32,7 +30,30 @@ namespace math {
3230
*/
3331
template <typename T, require_rev_col_vector_t<T>* = nullptr>
3432
inline auto simplex_constrain(const T& y) {
35-
return softmax(sum_to_zero_constrain(y));
33+
using ret_type = plain_type_t<T>;
34+
35+
const auto N = y.size();
36+
arena_t<T> arena_y = y;
37+
38+
arena_t<ret_type> arena_x = simplex_constrain(arena_y.val());
39+
40+
if (unlikely(N == 0)) {
41+
return ret_type(arena_x);
42+
}
43+
44+
reverse_pass_callback([arena_y, arena_x]() mutable {
45+
const auto& res_val = to_ref(arena_x.val());
46+
47+
Eigen::VectorXd x_pre_softmax_adj = Eigen::VectorXd::Zero(res_val.size());
48+
// backprop for softmax
49+
x_pre_softmax_adj += -res_val * arena_x.adj().dot(res_val)
50+
+ res_val.cwiseProduct(arena_x.adj());
51+
52+
// backprop for sum_to_zero_constrain
53+
internal::sum_to_zero_vector_backprop(arena_y, x_pre_softmax_adj);
54+
});
55+
56+
return ret_type(arena_x);
3657
}
3758

3859
/**
@@ -50,16 +71,36 @@ inline auto simplex_constrain(const T& y) {
5071
* @return Simplex of dimensionality N + 1.
5172
*/
5273
template <typename T, require_rev_col_vector_t<T>* = nullptr>
53-
auto simplex_constrain(const T& y, scalar_type_t<T>& lp) {
74+
inline auto simplex_constrain(const T& y, scalar_type_t<T>& lp) {
5475
using ret_type = plain_type_t<T>;
5576

56-
arena_t<ret_type> log_x = log_softmax(sum_to_zero_constrain(y));
77+
const auto N = y.size();
78+
arena_t<T> arena_y = y;
79+
80+
double lp_val = 0.0;
81+
arena_t<ret_type> arena_x = simplex_constrain(arena_y.val(), lp_val);
82+
lp += lp_val;
83+
84+
if (unlikely(N == 0)) {
85+
return ret_type(arena_x);
86+
}
87+
88+
reverse_pass_callback([arena_y, arena_x, lp]() mutable {
89+
const auto& res_val = to_ref(arena_x.val());
90+
91+
// backprop for log jacobian contribution to log density
92+
arena_x.adj().array() += lp.adj() / res_val.array();
5793

58-
const auto N = y.size() + 1;
94+
Eigen::VectorXd x_pre_softmax_adj = Eigen::VectorXd::Zero(res_val.size());
95+
// backprop for softmax
96+
x_pre_softmax_adj += -res_val * arena_x.adj().dot(res_val)
97+
+ res_val.cwiseProduct(arena_x.adj());
5998

60-
lp += sum(log_x) + 0.5 * log(N);
99+
// backprop for sum_to_zero_constrain
100+
internal::sum_to_zero_vector_backprop(arena_y, x_pre_softmax_adj);
101+
});
61102

62-
return ret_type(exp(log_x));
103+
return ret_type(arena_x);
63104
}
64105

65106
} // namespace math

stan/math/rev/constraint/sum_to_zero_constrain.hpp

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,35 @@
1515
namespace stan {
1616
namespace math {
1717

18+
namespace internal {
19+
20+
/**
21+
* The reverse pass backprop for the sum_to_zero_constrain on
22+
* vectors. This is separated out so it can also be called by
23+
* simplex_constrain.
24+
*/
25+
template <typename T>
26+
void sum_to_zero_vector_backprop(T&& arena_y, Eigen::VectorXd z_adj) {
27+
const auto N = arena_y.size();
28+
29+
double sum_u_adj = 0;
30+
for (int i = 0; i < N; ++i) {
31+
double n = static_cast<double>(i + 1);
32+
33+
// adjoint of the reverse cumulative sum computed in the forward mode
34+
sum_u_adj += z_adj.coeff(i);
35+
36+
// adjoint of the offset subtraction
37+
double v_adj = -z_adj.coeff(i + 1) * n;
38+
39+
double w_adj = v_adj + sum_u_adj;
40+
41+
arena_y.adj().coeffRef(i) += w_adj / sqrt(n * (n + 1));
42+
}
43+
}
44+
45+
} // namespace internal
46+
1847
/**
1948
* Return a vector with sum zero corresponding to the specified
2049
* free vector.
@@ -49,22 +78,7 @@ inline auto sum_to_zero_constrain(T&& y) {
4978
arena_t<ret_type> arena_z = sum_to_zero_constrain(arena_y.val());
5079

5180
reverse_pass_callback([arena_y, arena_z]() mutable {
52-
const auto N = arena_y.size();
53-
54-
double sum_u_adj = 0;
55-
for (int i = 0; i < N; ++i) {
56-
double n = static_cast<double>(i + 1);
57-
58-
// adjoint of the reverse cumulative sum computed in the forward mode
59-
sum_u_adj += arena_z.adj().coeff(i);
60-
61-
// adjoint of the offset subtraction
62-
double v_adj = -arena_z.adj().coeff(i + 1) * n;
63-
64-
double w_adj = v_adj + sum_u_adj;
65-
66-
arena_y.adj().coeffRef(i) += w_adj / sqrt(n * (n + 1));
67-
}
81+
internal::sum_to_zero_vector_backprop(arena_y, arena_z.adj());
6882
});
6983

7084
return arena_z;

0 commit comments

Comments
 (0)