Skip to content

Commit 82ac58e

Browse files
committed
Additional cleanups
1 parent a7d2791 commit 82ac58e

17 files changed

Lines changed: 42 additions & 38 deletions

stan/math/rev/fun/beta.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ inline auto beta(const Scalar& a, const VarMat& b) {
180180
- digamma(arena_a + arena_b.val().array()))
181181
* beta_val.array());
182182
return make_callback_var(
183-
beta_val, [arena_a, arena_b, digamma_ab](auto& vi) mutable {
183+
beta_val, [arena_b, digamma_ab](auto& vi) mutable {
184184
arena_b.adj().array() += vi.adj().array() * digamma_ab.array();
185185
});
186186
}
@@ -210,7 +210,7 @@ inline auto beta(const VarMat& a, const Scalar& b) {
210210
auto digamma_ab = to_arena(digamma(arena_a.val()).array()
211211
- digamma(arena_a.val().array() + arena_b));
212212
return make_callback_var(beta(arena_a.val(), arena_b),
213-
[arena_a, arena_b, digamma_ab](auto& vi) mutable {
213+
[arena_a, digamma_ab](auto& vi) mutable {
214214
arena_a.adj().array() += vi.adj().array()
215215
* digamma_ab
216216
* vi.val().array();

stan/math/rev/fun/lb_constrain.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ inline auto lb_constrain(const T& x, const L& lb, var& lp) {
111111
});
112112
} else {
113113
return make_callback_var(std::exp(value_of(x)) + lb_val,
114-
[lp, arena_lb = var(lb)](auto& vi) mutable {
114+
[arena_lb = var(lb)](auto& vi) mutable {
115115
arena_lb.adj() += vi.adj();
116116
});
117117
}
@@ -211,7 +211,7 @@ inline auto lb_constrain(const T& x, const L& lb, return_type_t<T, L>& lp) {
211211
const auto& x_ref = to_ref(x);
212212
lp += sum(x_ref);
213213
arena_t<ret_type> ret = value_of(x_ref).array().exp() + lb_val;
214-
reverse_pass_callback([ret, lp, arena_lb = var(lb)]() mutable {
214+
reverse_pass_callback([ret, arena_lb = var(lb)]() mutable {
215215
arena_lb.adj() += ret.adj().sum();
216216
});
217217
return ret_type(ret);
@@ -333,7 +333,6 @@ inline auto lb_constrain(const T& x, const L& lb, return_type_t<T, L>& lp) {
333333
auto exp_x = to_arena(arena_x.val().array().exp());
334334
arena_t<ret_type> ret
335335
= (is_not_inf_lb).select(exp_x + lb_val, arena_x.val().array());
336-
auto lp_old = lp;
337336
lp += (is_not_inf_lb).select(arena_x.val(), 0).sum();
338337
reverse_pass_callback([arena_x, ret, exp_x, lp, is_not_inf_lb]() mutable {
339338
const auto lp_adj = lp.adj();

stan/math/rev/fun/lub_constrain.hpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -313,28 +313,29 @@ inline auto lub_constrain(const T& x, const L& lb, const U& ub,
313313
} else {
314314
arena_t<T> arena_x = x;
315315
arena_t<L> arena_lb = lb;
316-
const auto x_val = value_of(arena_x).array();
316+
const auto arena_x_val = to_arena(value_of(arena_x).array());
317317
const auto lb_val = value_of(arena_lb).array().eval();
318318
check_less("lub_constrain", "lb", lb_val, ub_val);
319319
auto is_lb_inf = to_arena((lb_val == NEGATIVE_INFTY));
320320
auto diff = to_arena(ub_val - lb_val);
321-
auto neg_abs_x = to_arena(-(value_of(arena_x).array()).abs());
322-
auto inv_logit_x = to_arena(inv_logit(value_of(arena_x).array()));
321+
auto neg_abs_x = to_arena(-arena_x_val.abs());
322+
auto inv_logit_x = to_arena(inv_logit(arena_x_val));
323323
arena_t<ret_type> ret = (is_lb_inf).select(
324-
ub_val - value_of(arena_x).array().exp(), diff * inv_logit_x + lb_val);
324+
ub_val - arena_x_val.exp(), diff * inv_logit_x + lb_val);
325325
lp += (is_lb_inf)
326-
.select(value_of(arena_x).array(),
326+
.select(arena_x_val,
327327
log(diff) + (neg_abs_x - (2.0 * log1p_exp(neg_abs_x))))
328328
.sum();
329-
reverse_pass_callback([arena_x, ub, arena_lb, ret, lp, diff, inv_logit_x,
329+
reverse_pass_callback(
330+
[arena_x, arena_x_val, ub, arena_lb, ret, lp, diff, inv_logit_x,
330331
is_lb_inf]() mutable {
331332
using T_var = arena_t<promote_scalar_t<var, T>>;
332333
using L_var = arena_t<promote_scalar_t<var, L>>;
333334
const auto lp_adj = lp.adj();
334335
if (!is_constant<T>::value) {
335-
const auto x_sign = value_of(arena_x).array().sign().eval();
336+
const auto x_sign = arena_x_val.sign().eval();
336337
forward_as<T_var>(arena_x).adj().array() += (is_lb_inf).select(
337-
ret.adj().array() * -value_of(arena_x).array().exp() + lp_adj,
338+
ret.adj().array() * -arena_x_val.exp() + lp_adj,
338339
ret.adj().array() * diff * inv_logit_x * (1.0 - inv_logit_x)
339340
+ lp.adj() * (1.0 - 2.0 * inv_logit_x));
340341
}

stan/math/rev/fun/read_cov_L.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ inline auto read_cov_L(const T_CPCs& CPCs, const T_sds& sds,
3939
var_value<Eigen::MatrixXd> res
4040
= sds.val().matrix().asDiagonal() * corr_L.val();
4141

42-
reverse_pass_callback([CPCs, sds, corr_L, log_prob, res]() mutable {
42+
reverse_pass_callback([sds, corr_L, log_prob, res]() mutable {
4343
size_t K = sds.size();
4444

4545
corr_L.adj() += sds.val().matrix().asDiagonal() * res.adj();

stan/math/rev/fun/svd.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ inline auto svd(const EigMat& m) {
6161
arena_t<mat_ret_type> arena_V = svd.matrixV();
6262

6363
reverse_pass_callback([arena_m, arena_U, singular_values, arena_V, arena_Fp,
64-
arena_Fm, M]() mutable {
64+
arena_Fm]() mutable {
6565
// SVD-U reverse mode
6666
Eigen::MatrixXd UUadjT = arena_U.val_op().transpose() * arena_U.adj_op();
6767
auto u_adj

stan/math/rev/fun/svd_U.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ inline auto svd_U(const EigMat& m) {
5252
arena_t<ret_type> arena_U = svd.matrixU();
5353
auto arena_V = to_arena(svd.matrixV());
5454

55-
reverse_pass_callback([arena_m, arena_U, arena_D, arena_V, arena_Fp,
56-
M]() mutable {
55+
reverse_pass_callback(
56+
[arena_m, arena_U, arena_D, arena_V, arena_Fp]() mutable {
5757
Eigen::MatrixXd UUadjT = arena_U.val_op().transpose() * arena_U.adj_op();
5858
arena_m.adj()
5959
+= .5 * arena_U.val_op()

stan/math/rev/fun/svd_V.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ inline auto svd_V(const EigMat& m) {
5252
auto arena_U = to_arena(svd.matrixU());
5353
arena_t<ret_type> arena_V = svd.matrixV();
5454

55-
reverse_pass_callback([arena_m, arena_U, arena_D, arena_V, arena_Fm,
56-
M]() mutable {
55+
reverse_pass_callback(
56+
[arena_m, arena_U, arena_D, arena_V, arena_Fm]() mutable {
5757
Eigen::MatrixXd VTVadj = arena_V.val_op().transpose() * arena_V.adj_op();
5858
arena_m.adj()
5959
+= 0.5 * arena_U

stan/math/rev/fun/ub_constrain.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ inline auto ub_constrain(const T& x, const U& ub, return_type_t<T, U>& lp) {
214214
auto x_ref = to_ref(value_of(x));
215215
arena_t<ret_type> ret = ub_val - x_ref.array().exp();
216216
lp += x_ref.sum();
217-
reverse_pass_callback([ret, lp, arena_ub = var(ub)]() mutable {
217+
reverse_pass_callback([ret, arena_ub = var(ub)]() mutable {
218218
arena_ub.adj() += ret.adj().sum();
219219
});
220220
return ret_type(ret);

test/unit/math/mix/fun/abs_test.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ TEST(mixFun, absBasics) {
1717
// test int -> int vectorization
1818
std::vector<int> u{1, 2, 3, 4};
1919
std::vector<int> v = abs(u);
20+
21+
EXPECT_FLOAT_EQ(a, 1);
22+
EXPECT_FLOAT_EQ(b, 2.3);
23+
EXPECT_MATRIX_EQ(x, y);
24+
EXPECT_STD_VECTOR_EQ(u, v);
2025
}
2126

2227
template <typename T1, typename T2>
@@ -231,9 +236,11 @@ TEST(mixFun, absReturnType) {
231236
// validate return types not overpromoted to complex by assignability
232237
std::complex<stan::math::var> a = 3;
233238
stan::math::var b = abs(a);
239+
EXPECT_FLOAT_EQ(b.val(), 3);
234240

235241
std::complex<stan::math::fvar<double>> c = 3;
236242
stan::math::fvar<double> d = abs(c);
243+
EXPECT_FLOAT_EQ(d.val(), 3);
237244
SUCCEED();
238245
}
239246

test/unit/math/mix/fun/eigendecompose_identity_complex_test.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ TEST(mathMixFun, eigenvectorsIdComplex) {
4040
using fv_t = stan::math::fvar<stan::math::var>;
4141
using ffv_t = stan::math::fvar<fv_t>;
4242

43+
expectComplexEigenvectorsId<d_t>();
4344
expectComplexEigenvectorsId<v_t>();
4445
expectComplexEigenvectorsId<fd_t>();
4546
expectComplexEigenvectorsId<ffd_t>();

0 commit comments

Comments
 (0)