Skip to content

Commit a845875

Browse files
committed
Merge remote-tracking branch 'origin/develop' into feature/threadsafe-matrixcl
2 parents b5fbe4a + 34881d4 commit a845875

15 files changed

Lines changed: 277 additions & 76 deletions

stan/math/prim/fun/promote_scalar.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,11 @@ inline auto promote_scalar(UnPromotedType&& x) {
5959

6060
// Forward decl for iterating over tuples used in std::vector<tuple>
6161
template <typename PromotionScalars, typename UnPromotedTypes,
62-
require_all_tuple_t<PromotionScalars, UnPromotedTypes>* = nullptr>
62+
require_all_tuple_t<PromotionScalars, UnPromotedTypes>* = nullptr,
63+
require_not_same_t<PromotionScalars, UnPromotedTypes>* = nullptr>
6364
inline constexpr promote_scalar_t<PromotionScalars, UnPromotedTypes>
6465
promote_scalar(UnPromotedTypes&& x);
66+
6567
/**
6668
* Promote the scalar type of an standard vector to the requested type.
6769
*
@@ -93,7 +95,8 @@ inline auto promote_scalar(UnPromotedType&& x) {
9395
* @param x input
9496
*/
9597
template <typename PromotionScalars, typename UnPromotedTypes,
96-
require_all_tuple_t<PromotionScalars, UnPromotedTypes>*>
98+
require_all_tuple_t<PromotionScalars, UnPromotedTypes>*,
99+
require_not_same_t<PromotionScalars, UnPromotedTypes>*>
97100
inline constexpr promote_scalar_t<PromotionScalars, UnPromotedTypes>
98101
promote_scalar(UnPromotedTypes&& x) {
99102
return index_apply<std::tuple_size<std::decay_t<UnPromotedTypes>>::value>(

stan/math/prim/fun/stan_print.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ void stan_print(std::ostream* o, const EigMat& x) {
5252
*o << ']';
5353
}
5454

55+
// forward decl to allow the next two overloads to call each other
56+
template <typename T, require_tuple_t<T>* = nullptr>
57+
void stan_print(std::ostream* o, const T& x);
58+
5559
template <typename T, require_std_vector_t<T>* = nullptr>
5660
void stan_print(std::ostream* o, const T& x) {
5761
*o << '[';
@@ -64,7 +68,7 @@ void stan_print(std::ostream* o, const T& x) {
6468
*o << ']';
6569
}
6670

67-
template <typename T, require_tuple_t<T>* = nullptr>
71+
template <typename T, require_tuple_t<T>*>
6872
void stan_print(std::ostream* o, const T& x) {
6973
*o << '(';
7074
constexpr auto tuple_size = std::tuple_size<std::decay_t<T>>::value;

stan/math/prim/prob/bernoulli_logit_glm_lpmf.hpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
#include <stan/math/prim/fun/as_array_or_scalar.hpp>
99
#include <stan/math/prim/fun/constants.hpp>
1010
#include <stan/math/prim/fun/exp.hpp>
11+
#include <stan/math/prim/fun/isfinite.hpp>
1112
#include <stan/math/prim/fun/size.hpp>
1213
#include <stan/math/prim/fun/size_zero.hpp>
1314
#include <stan/math/prim/fun/to_ref.hpp>
14-
#include <stan/math/prim/fun/value_of_rec.hpp>
15+
#include <stan/math/prim/fun/value_of.hpp>
1516
#include <stan/math/prim/functor/partials_propagator.hpp>
1617
#include <cmath>
1718

@@ -54,11 +55,16 @@ return_type_t<T_x, T_alpha, T_beta> bernoulli_logit_glm_lpmf(
5455
using Eigen::log1p;
5556
using Eigen::Matrix;
5657
using std::exp;
58+
using std::isfinite;
5759
constexpr int T_x_rows = T_x::RowsAtCompileTime;
60+
using T_xbeta_partials = partials_return_t<T_x, T_beta>;
5861
using T_partials_return = partials_return_t<T_y, T_x, T_alpha, T_beta>;
5962
using T_ytheta_tmp =
6063
typename std::conditional_t<T_x_rows == 1, T_partials_return,
6164
Array<T_partials_return, Dynamic, 1>>;
65+
using T_xbeta_tmp =
66+
typename std::conditional_t<T_x_rows == 1, T_xbeta_partials,
67+
Array<T_xbeta_partials, Dynamic, 1>>;
6268
using T_x_ref = ref_type_if_t<!is_constant<T_x>::value, T_x>;
6369
using T_alpha_ref = ref_type_if_t<!is_constant<T_alpha>::value, T_alpha>;
6470
using T_beta_ref = ref_type_if_t<!is_constant<T_beta>::value, T_beta>;
@@ -86,11 +92,10 @@ return_type_t<T_x, T_alpha, T_beta> bernoulli_logit_glm_lpmf(
8692
T_alpha_ref alpha_ref = alpha;
8793
T_beta_ref beta_ref = beta;
8894

89-
const auto& y_val = value_of_rec(y_ref);
90-
const auto& x_val
91-
= to_ref_if<!is_constant<T_beta>::value>(value_of_rec(x_ref));
92-
const auto& alpha_val = value_of_rec(alpha_ref);
93-
const auto& beta_val = value_of_rec(beta_ref);
95+
const auto& y_val = value_of(y_ref);
96+
const auto& x_val = to_ref_if<!is_constant<T_beta>::value>(value_of(x_ref));
97+
const auto& alpha_val = value_of(alpha_ref);
98+
const auto& beta_val = value_of(beta_ref);
9499

95100
const auto& y_val_vec = as_column_vector_or_scalar(y_val);
96101
const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val);
@@ -103,7 +108,7 @@ return_type_t<T_x, T_alpha, T_beta> bernoulli_logit_glm_lpmf(
103108
Array<T_partials_return, Dynamic, 1> ytheta(N_instances);
104109
if (T_x_rows == 1) {
105110
T_ytheta_tmp ytheta_tmp
106-
= forward_as<T_ytheta_tmp>((x_val * beta_val_vec)(0, 0));
111+
= forward_as<T_xbeta_tmp>((x_val * beta_val_vec)(0, 0));
107112
ytheta = signs * (ytheta_tmp + as_array_or_scalar(alpha_val_vec));
108113
} else {
109114
ytheta = (x_val * beta_val_vec).array();
@@ -120,7 +125,7 @@ return_type_t<T_x, T_alpha, T_beta> bernoulli_logit_glm_lpmf(
120125
.select(-exp_m_ytheta,
121126
(ytheta < -cutoff).select(ytheta, -log1p(exp_m_ytheta))));
122127

123-
if (!std::isfinite(logp)) {
128+
if (!isfinite(logp)) {
124129
check_finite(function, "Weight vector", beta);
125130
check_finite(function, "Intercept", alpha);
126131
check_finite(function, "Matrix of independent variables", ytheta);
@@ -133,7 +138,7 @@ return_type_t<T_x, T_alpha, T_beta> bernoulli_logit_glm_lpmf(
133138
= (ytheta > cutoff)
134139
.select(-exp_m_ytheta,
135140
(ytheta < -cutoff)
136-
.select(signs,
141+
.select(signs * T_partials_return(1.0),
137142
signs * exp_m_ytheta / (exp_m_ytheta + 1)));
138143
if (!is_constant_all<T_beta>::value) {
139144
if (T_x_rows == 1) {

stan/math/prim/prob/categorical_logit_glm_lpmf.hpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
#include <stan/math/prim/fun/as_column_vector_or_scalar.hpp>
77
#include <stan/math/prim/fun/as_array_or_scalar.hpp>
88
#include <stan/math/prim/fun/exp.hpp>
9+
#include <stan/math/prim/fun/isfinite.hpp>
910
#include <stan/math/prim/fun/log.hpp>
1011
#include <stan/math/prim/fun/scalar_seq_view.hpp>
1112
#include <stan/math/prim/fun/size.hpp>
1213
#include <stan/math/prim/fun/size_zero.hpp>
1314
#include <stan/math/prim/fun/to_ref.hpp>
14-
#include <stan/math/prim/fun/value_of_rec.hpp>
15+
#include <stan/math/prim/fun/value_of.hpp>
1516
#include <stan/math/prim/functor/partials_propagator.hpp>
1617
#include <stan/math/prim/fun/Eigen.hpp>
1718
#include <cmath>
@@ -50,11 +51,13 @@ return_type_t<T_x, T_alpha, T_beta> categorical_logit_glm_lpmf(
5051
using Eigen::Dynamic;
5152
using Eigen::Matrix;
5253
using std::exp;
54+
using std::isfinite;
5355
using std::log;
5456
using T_y_ref = ref_type_t<T_y>;
5557
using T_x_ref = ref_type_if_t<!is_constant<T_x>::value, T_x>;
5658
using T_alpha_ref = ref_type_if_t<!is_constant<T_alpha>::value, T_alpha>;
5759
using T_beta_ref = ref_type_if_t<!is_constant<T_beta>::value, T_beta>;
60+
using T_beta_partials = partials_type_t<scalar_type_t<T_beta>>;
5861
constexpr int T_x_rows = T_x::RowsAtCompileTime;
5962

6063
const size_t N_instances = T_x_rows == 1 ? stan::math::size(y) : x.rows();
@@ -82,11 +85,10 @@ return_type_t<T_x, T_alpha, T_beta> categorical_logit_glm_lpmf(
8285
T_alpha_ref alpha_ref = alpha;
8386
T_beta_ref beta_ref = beta;
8487

85-
const auto& x_val
86-
= to_ref_if<!is_constant<T_beta>::value>(value_of_rec(x_ref));
87-
const auto& alpha_val = value_of_rec(alpha_ref);
88+
const auto& x_val = to_ref_if<!is_constant<T_beta>::value>(value_of(x_ref));
89+
const auto& alpha_val = value_of(alpha_ref);
8890
const auto& beta_val
89-
= to_ref_if<!is_constant<T_x>::value>(value_of_rec(beta_ref));
91+
= to_ref_if<!is_constant<T_x>::value>(value_of(beta_ref));
9092

9193
const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val).transpose();
9294

@@ -117,7 +119,7 @@ return_type_t<T_x, T_alpha, T_beta> categorical_logit_glm_lpmf(
117119
// when we have newer Eigen T_partials_return logp =
118120
// lin(Eigen::all,y-1).sum() + log(inv_sum_exp_lin).sum() - lin_max.sum();
119121

120-
if (!std::isfinite(logp)) {
122+
if (!isfinite(logp)) {
121123
check_finite(function, "Weight vector", beta_ref);
122124
check_finite(function, "Intercept", alpha_ref);
123125
check_finite(function, "Matrix of independent variables", x_ref);
@@ -128,7 +130,7 @@ return_type_t<T_x, T_alpha, T_beta> categorical_logit_glm_lpmf(
128130

129131
if (!is_constant_all<T_x>::value) {
130132
if (T_x_rows == 1) {
131-
Array<double, 1, Dynamic> beta_y = beta_val.col(y_seq[0] - 1);
133+
Array<T_beta_partials, 1, Dynamic> beta_y = beta_val.col(y_seq[0] - 1);
132134
for (int i = 1; i < N_instances; i++) {
133135
beta_y += beta_val.col(y_seq[i] - 1).array();
134136
}
@@ -137,7 +139,8 @@ return_type_t<T_x, T_alpha, T_beta> categorical_logit_glm_lpmf(
137139
- (exp_lin.matrix() * beta_val.transpose()).array().colwise()
138140
* inv_sum_exp_lin * N_instances;
139141
} else {
140-
Array<double, Dynamic, Dynamic> beta_y(N_instances, N_attributes);
142+
Array<T_beta_partials, Dynamic, Dynamic> beta_y(N_instances,
143+
N_attributes);
141144
for (int i = 0; i < N_instances; i++) {
142145
beta_y.row(i) = beta_val.col(y_seq[i] - 1);
143146
}
@@ -166,12 +169,11 @@ return_type_t<T_x, T_alpha, T_beta> categorical_logit_glm_lpmf(
166169
}
167170
}
168171
if (!is_constant_all<T_beta>::value) {
169-
Matrix<T_partials_return, Dynamic, Dynamic> beta_derivative;
172+
Matrix<T_partials_return, Dynamic, Dynamic> beta_derivative
173+
= x_val.transpose().template cast<T_partials_return>()
174+
* neg_softmax_lin.matrix();
170175
if (T_x_rows == 1) {
171-
beta_derivative
172-
= x_val.transpose() * neg_softmax_lin.matrix() * N_instances;
173-
} else {
174-
beta_derivative = x_val.transpose() * neg_softmax_lin.matrix();
176+
beta_derivative *= N_instances;
175177
}
176178

177179
for (int i = 0; i < N_instances; i++) {

stan/math/prim/prob/neg_binomial_2_log_glm_lpmf.hpp

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#include <stan/math/prim/fun/size.hpp>
1616
#include <stan/math/prim/fun/sum.hpp>
1717
#include <stan/math/prim/fun/to_ref.hpp>
18-
#include <stan/math/prim/fun/value_of_rec.hpp>
18+
#include <stan/math/prim/fun/value_of.hpp>
1919
#include <stan/math/prim/functor/partials_propagator.hpp>
2020
#include <vector>
2121
#include <cmath>
@@ -72,6 +72,7 @@ return_type_t<T_x, T_alpha, T_beta, T_precision> neg_binomial_2_log_glm_lpmf(
7272
using Eigen::log1p;
7373
using Eigen::Matrix;
7474
constexpr int T_x_rows = T_x::RowsAtCompileTime;
75+
using T_xbeta_partials = partials_return_t<T_x, T_beta>;
7576
using T_partials_return
7677
= partials_return_t<T_y, T_x, T_alpha, T_beta, T_precision>;
7778
using T_precision_val = typename std::conditional_t<
@@ -85,6 +86,9 @@ return_type_t<T_x, T_alpha, T_beta, T_precision> neg_binomial_2_log_glm_lpmf(
8586
using T_theta_tmp =
8687
typename std::conditional_t<T_x_rows == 1, T_partials_return,
8788
Array<T_partials_return, Dynamic, 1>>;
89+
using T_xbeta_tmp =
90+
typename std::conditional_t<T_x_rows == 1, T_xbeta_partials,
91+
Array<T_xbeta_partials, Dynamic, 1>>;
8892
using T_x_ref = ref_type_if_t<!is_constant<T_x>::value, T_x>;
8993
using T_alpha_ref = ref_type_if_t<!is_constant<T_alpha>::value, T_alpha>;
9094
using T_beta_ref = ref_type_if_t<!is_constant<T_beta>::value, T_beta>;
@@ -103,8 +107,8 @@ return_type_t<T_x, T_alpha, T_beta, T_precision> neg_binomial_2_log_glm_lpmf(
103107
check_consistent_size(function, "Vector of intercepts", alpha, N_instances);
104108
T_alpha_ref alpha_ref = alpha;
105109
T_beta_ref beta_ref = beta;
106-
const auto& beta_val = value_of_rec(beta_ref);
107-
const auto& alpha_val = value_of_rec(alpha_ref);
110+
const auto& beta_val = value_of(beta_ref);
111+
const auto& alpha_val = value_of(alpha_ref);
108112
const auto& beta_val_vec = to_ref(as_column_vector_or_scalar(beta_val));
109113
const auto& alpha_val_vec = to_ref(as_column_vector_or_scalar(alpha_val));
110114
check_finite(function, "Weight vector", beta_val_vec);
@@ -117,8 +121,8 @@ return_type_t<T_x, T_alpha, T_beta, T_precision> neg_binomial_2_log_glm_lpmf(
117121
const auto& y_ref = to_ref(y);
118122
T_phi_ref phi_ref = phi;
119123

120-
const auto& y_val = value_of_rec(y_ref);
121-
const auto& phi_val = value_of_rec(phi_ref);
124+
const auto& y_val = value_of(y_ref);
125+
const auto& phi_val = value_of(phi_ref);
122126

123127
const auto& y_val_vec = to_ref(as_column_vector_or_scalar(y_val));
124128
const auto& phi_val_vec = to_ref(as_column_vector_or_scalar(phi_val));
@@ -131,16 +135,15 @@ return_type_t<T_x, T_alpha, T_beta, T_precision> neg_binomial_2_log_glm_lpmf(
131135

132136
T_x_ref x_ref = x;
133137

134-
const auto& x_val
135-
= to_ref_if<!is_constant<T_beta>::value>(value_of_rec(x_ref));
138+
const auto& x_val = to_ref_if<!is_constant<T_beta>::value>(value_of(x_ref));
136139

137140
const auto& y_arr = as_array_or_scalar(y_val_vec);
138141
const auto& phi_arr = as_array_or_scalar(phi_val_vec);
139142

140143
Array<T_partials_return, Dynamic, 1> theta(N_instances);
141144
if (T_x_rows == 1) {
142145
T_theta_tmp theta_tmp
143-
= forward_as<T_theta_tmp>((x_val * beta_val_vec)(0, 0));
146+
= forward_as<T_xbeta_tmp>((x_val * beta_val_vec)(0, 0));
144147
theta = theta_tmp + as_array_or_scalar(alpha_val_vec);
145148
} else {
146149
theta = (x_val * beta_val_vec).array();
@@ -171,10 +174,11 @@ return_type_t<T_x, T_alpha, T_beta, T_precision> neg_binomial_2_log_glm_lpmf(
171174
logp += multiply_log(phi_vec[n], phi_vec[n]) - lgamma(phi_vec[n]);
172175
}
173176
} else {
177+
using T_phi_scalar = scalar_type_t<decltype(phi_val_vec)>;
174178
logp += N_instances
175-
* (multiply_log(forward_as<double>(phi_val),
176-
forward_as<double>(phi_val))
177-
- lgamma(forward_as<double>(phi_val)));
179+
* (multiply_log(forward_as<T_phi_scalar>(phi_val),
180+
forward_as<T_phi_scalar>(phi_val))
181+
- lgamma(forward_as<T_phi_scalar>(phi_val)));
178182
}
179183
}
180184
logp -= sum(y_plus_phi * logsumexp_theta_logphi);

stan/math/prim/prob/normal_id_glm_lpdf.hpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
#include <stan/math/prim/fun/as_column_vector_or_scalar.hpp>
77
#include <stan/math/prim/fun/as_array_or_scalar.hpp>
88
#include <stan/math/prim/fun/constants.hpp>
9+
#include <stan/math/prim/fun/isfinite.hpp>
910
#include <stan/math/prim/fun/log.hpp>
1011
#include <stan/math/prim/fun/size.hpp>
1112
#include <stan/math/prim/fun/size_zero.hpp>
1213
#include <stan/math/prim/fun/sum.hpp>
1314
#include <stan/math/prim/fun/to_ref.hpp>
14-
#include <stan/math/prim/fun/value_of_rec.hpp>
15+
#include <stan/math/prim/fun/value_of.hpp>
1516
#include <stan/math/prim/functor/partials_propagator.hpp>
1617
#include <cmath>
1718

@@ -59,6 +60,7 @@ return_type_t<T_y, T_x, T_alpha, T_beta, T_scale> normal_id_glm_lpdf(
5960
using Eigen::Dynamic;
6061
using Eigen::Matrix;
6162
using Eigen::VectorXd;
63+
using std::isfinite;
6264
constexpr int T_x_rows = T_x::RowsAtCompileTime;
6365
using T_partials_return
6466
= partials_return_t<T_y, T_x, T_alpha, T_beta, T_scale>;
@@ -86,7 +88,7 @@ return_type_t<T_y, T_x, T_alpha, T_beta, T_scale> normal_id_glm_lpdf(
8688
N_instances);
8789
check_consistent_size(function, "Vector of intercepts", alpha, N_instances);
8890
T_sigma_ref sigma_ref = sigma;
89-
const auto& sigma_val = value_of_rec(sigma_ref);
91+
const auto& sigma_val = value_of(sigma_ref);
9092
const auto& sigma_val_vec = to_ref(as_column_vector_or_scalar(sigma_val));
9193
check_positive_finite(function, "Scale vector", sigma_val_vec);
9294

@@ -102,11 +104,10 @@ return_type_t<T_y, T_x, T_alpha, T_beta, T_scale> normal_id_glm_lpdf(
102104
T_alpha_ref alpha_ref = alpha;
103105
T_beta_ref beta_ref = beta;
104106

105-
const auto& y_val = value_of_rec(y_ref);
106-
const auto& x_val
107-
= to_ref_if<!is_constant<T_beta>::value>(value_of_rec(x_ref));
108-
const auto& alpha_val = value_of_rec(alpha_ref);
109-
const auto& beta_val = value_of_rec(beta_ref);
107+
const auto& y_val = value_of(y_ref);
108+
const auto& x_val = to_ref_if<!is_constant<T_beta>::value>(value_of(x_ref));
109+
const auto& alpha_val = value_of(alpha_ref);
110+
const auto& beta_val = value_of(beta_ref);
110111

111112
const auto& y_val_vec = as_column_vector_or_scalar(y_val);
112113
const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val);
@@ -116,7 +117,7 @@ return_type_t<T_y, T_x, T_alpha, T_beta, T_scale> normal_id_glm_lpdf(
116117
T_scale_val inv_sigma = 1.0 / as_array_or_scalar(sigma_val_vec);
117118

118119
// the most efficient way to calculate this depends on template parameters
119-
double y_scaled_sq_sum;
120+
T_partials_return y_scaled_sq_sum;
120121

121122
Array<T_partials_return, Dynamic, 1> y_scaled(N_instances);
122123
if (T_x_rows == 1) {
@@ -178,7 +179,8 @@ return_type_t<T_y, T_x, T_alpha, T_beta, T_scale> normal_id_glm_lpdf(
178179
} else {
179180
y_scaled_sq_sum = sum(y_scaled * y_scaled);
180181
partials<4>(ops_partials)[0]
181-
= (y_scaled_sq_sum - N_instances) * forward_as<double>(inv_sigma);
182+
= (y_scaled_sq_sum - N_instances)
183+
* forward_as<partials_return_t<T_sigma_ref>>(inv_sigma);
182184
}
183185
} else {
184186
y_scaled_sq_sum = sum(y_scaled * y_scaled);
@@ -187,7 +189,7 @@ return_type_t<T_y, T_x, T_alpha, T_beta, T_scale> normal_id_glm_lpdf(
187189
y_scaled_sq_sum = sum(y_scaled * y_scaled);
188190
}
189191

190-
if (!std::isfinite(y_scaled_sq_sum)) {
192+
if (!isfinite(y_scaled_sq_sum)) {
191193
check_finite(function, "Vector of dependent variables", y_val_vec);
192194
check_finite(function, "Weight vector", beta_val_vec);
193195
check_finite(function, "Intercept", alpha_val_vec);
@@ -204,7 +206,8 @@ return_type_t<T_y, T_x, T_alpha, T_beta, T_scale> normal_id_glm_lpdf(
204206
if (is_vector<T_scale>::value) {
205207
logp -= sum(log(sigma_val_vec));
206208
} else {
207-
logp -= N_instances * log(forward_as<double>(sigma_val_vec));
209+
logp -= N_instances
210+
* log(forward_as<partials_return_t<T_sigma_ref>>(sigma_val_vec));
208211
}
209212
}
210213
logp -= 0.5 * y_scaled_sq_sum;

0 commit comments

Comments
 (0)