@@ -16,7 +16,8 @@ namespace math {
1616/* *
1717 * Return a row stochastic matrix.
1818 *
19- * The transform is based on a centered stick-breaking process.
19+ * The transform is defined using the inverse of the
20+ * isometric log ratio (ILR) transform
2021 *
2122 * @tparam Mat type of the Matrix
2223 * @param y Free Matrix input of dimensionality (N, K - 1).
@@ -27,23 +28,17 @@ template <typename Mat, require_eigen_matrix_dynamic_t<Mat>* = nullptr,
2728inline plain_type_t <Mat> stochastic_row_constrain (const Mat& y) {
2829 auto && y_ref = to_ref (y);
2930 const Eigen::Index N = y_ref.rows ();
30- int Km1 = y_ref.cols ();
31- plain_type_t <Mat> x (N, Km1 + 1 );
32- using eigen_arr = Eigen::Array<scalar_type_t <Mat>, -1 , 1 >;
33- eigen_arr stick_len = eigen_arr::Constant (N, 1.0 );
34- for (Eigen::Index k = 0 ; k < Km1; ++k) {
35- auto z_k = inv_logit (y_ref.array ().col (k) - log (Km1 - k));
36- x.array ().col (k) = stick_len * z_k;
37- stick_len -= x.array ().col (k);
31+ plain_type_t <Mat> ret (N, y_ref.cols () + 1 );
32+ for (Eigen::Index i = 0 ; i < N; ++i) {
33+ ret.row (i) = simplex_constrain (y_ref.row (i));
3834 }
39- x.array ().col (Km1) = stick_len;
40- return x;
35+ return ret;
4136}
4237
4338/* *
4439 * Return a row stochastic matrix.
45- * The simplex transform is defined through a centered
46- * stick-breaking process.
40+ * The simplex transform is defined using the inverse of the
41+ * isometric log ratio (ILR) transform
4742 *
4843 * @tparam Mat type of the matrix
4944 * @tparam Lp A scalar type for the lp argument. The scalar type of Mat should
@@ -59,21 +54,11 @@ template <typename Mat, typename Lp,
5954inline plain_type_t <Mat> stochastic_row_constrain (const Mat& y, Lp& lp) {
6055 auto && y_ref = to_ref (y);
6156 const Eigen::Index N = y_ref.rows ();
62- Eigen::Index Km1 = y_ref.cols ();
63- plain_type_t <Mat> x (N, Km1 + 1 );
64- Eigen::Array<scalar_type_t <Mat>, -1 , 1 > stick_len
65- = Eigen::Array<scalar_type_t <Mat>, -1 , 1 >::Constant (N, 1.0 );
66- for (Eigen::Index k = 0 ; k < Km1; ++k) {
67- const auto eq_share = -log (Km1 - k); // = logit(1.0/(Km1 + 1 - k));
68- auto adj_y_k = (y_ref.array ().col (k) + eq_share).eval ();
69- auto z_k = inv_logit (adj_y_k);
70- x.array ().col (k) = stick_len * z_k;
71- lp += -sum (log1p_exp (adj_y_k)) - sum (log1p_exp (-adj_y_k))
72- + sum (log (stick_len));
73- stick_len -= x.array ().col (k); // equivalently *= (1 - z_k);
57+ plain_type_t <Mat> ret (N, y_ref.cols () + 1 );
58+ for (Eigen::Index i = 0 ; i < N; ++i) {
59+ ret.row (i) = simplex_constrain (y_ref.row (i), lp);
7460 }
75- x.col (Km1).array () = stick_len;
76- return x;
61+ return ret;
7762}
7863
7964/* *
0 commit comments