Skip to content

Commit cf58058

Browse files
committed
Prim for stoch matrices
1 parent 15bed18 commit cf58058

2 files changed

Lines changed: 16 additions & 30 deletions

File tree

stan/math/prim/constraint/stochastic_column_constrain.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ namespace math {
1616
/**
1717
* Return a column stochastic matrix.
1818
*
19-
* The transform is based on a centered stick-breaking process.
19+
* The transform is defined using the inverse of the
20+
* isometric log ratio (ILR) transform
2021
*
2122
* @tparam Mat type of the Matrix
2223
* @param y Free Matrix input of dimensionality (K - 1, M)
@@ -39,8 +40,8 @@ inline plain_type_t<Mat> stochastic_column_constrain(const Mat& y) {
3940
* and increment the specified log probability reference with
4041
* the log absolute Jacobian determinant of the transform.
4142
*
42-
* The simplex transform is defined through a centered
43-
* stick-breaking process.
43+
* The simplex transform is defined using the inverse of the
44+
* isometric log ratio (ILR) transform
4445
*
4546
* @tparam Mat type of the Matrix
4647
* @tparam Lp A scalar type for the lp argument. The scalar type of Mat should

stan/math/prim/constraint/stochastic_row_constrain.hpp

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ namespace math {
1616
/**
1717
* Return a row stochastic matrix.
1818
*
19-
* The transform is based on a centered stick-breaking process.
19+
* The transform is defined using the inverse of the
20+
* isometric log ratio (ILR) transform
2021
*
2122
* @tparam Mat type of the Matrix
2223
* @param y Free Matrix input of dimensionality (N, K - 1).
@@ -27,23 +28,17 @@ template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr,
2728
inline plain_type_t<Mat> stochastic_row_constrain(const Mat& y) {
2829
auto&& y_ref = to_ref(y);
2930
const Eigen::Index N = y_ref.rows();
30-
int Km1 = y_ref.cols();
31-
plain_type_t<Mat> x(N, Km1 + 1);
32-
using eigen_arr = Eigen::Array<scalar_type_t<Mat>, -1, 1>;
33-
eigen_arr stick_len = eigen_arr::Constant(N, 1.0);
34-
for (Eigen::Index k = 0; k < Km1; ++k) {
35-
auto z_k = inv_logit(y_ref.array().col(k) - log(Km1 - k));
36-
x.array().col(k) = stick_len * z_k;
37-
stick_len -= x.array().col(k);
31+
plain_type_t<Mat> ret(N, y_ref.cols() + 1);
32+
for (Eigen::Index i = 0; i < N; ++i) {
33+
ret.row(i) = simplex_constrain(y_ref.row(i));
3834
}
39-
x.array().col(Km1) = stick_len;
40-
return x;
35+
return ret;
4136
}
4237

4338
/**
4439
* Return a row stochastic matrix.
45-
* The simplex transform is defined through a centered
46-
* stick-breaking process.
40+
* The simplex transform is defined using the inverse of the
41+
* isometric log ratio (ILR) transform
4742
*
4843
* @tparam Mat type of the matrix
4944
* @tparam Lp A scalar type for the lp argument. The scalar type of Mat should
@@ -59,21 +54,11 @@ template <typename Mat, typename Lp,
5954
inline plain_type_t<Mat> stochastic_row_constrain(const Mat& y, Lp& lp) {
6055
auto&& y_ref = to_ref(y);
6156
const Eigen::Index N = y_ref.rows();
62-
Eigen::Index Km1 = y_ref.cols();
63-
plain_type_t<Mat> x(N, Km1 + 1);
64-
Eigen::Array<scalar_type_t<Mat>, -1, 1> stick_len
65-
= Eigen::Array<scalar_type_t<Mat>, -1, 1>::Constant(N, 1.0);
66-
for (Eigen::Index k = 0; k < Km1; ++k) {
67-
const auto eq_share = -log(Km1 - k); // = logit(1.0/(Km1 + 1 - k));
68-
auto adj_y_k = (y_ref.array().col(k) + eq_share).eval();
69-
auto z_k = inv_logit(adj_y_k);
70-
x.array().col(k) = stick_len * z_k;
71-
lp += -sum(log1p_exp(adj_y_k)) - sum(log1p_exp(-adj_y_k))
72-
+ sum(log(stick_len));
73-
stick_len -= x.array().col(k); // equivalently *= (1 - z_k);
57+
plain_type_t<Mat> ret(N, y_ref.cols() + 1);
58+
for (Eigen::Index i = 0; i < N; ++i) {
59+
ret.row(i) = simplex_constrain(y_ref.row(i), lp);
7460
}
75-
x.col(Km1).array() = stick_len;
76-
return x;
61+
return ret;
7762
}
7863

7964
/**

0 commit comments

Comments
 (0)