Skip to content

Commit a4b384c

Browse files
committed
Merge branch 'develop' into issue-2783-bernoulli-cdf-stable
2 parents ccde5d5 + f4c6817 commit a4b384c

16 files changed

Lines changed: 722 additions & 260 deletions

File tree

make/tests

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ $(IDAS_TESTS) : $(LIBSUNDIALS)
8282
# KINSOL tests
8383
##
8484

85-
ALGEBRA_SOLVER_TESTS := $(subst .cpp,$(EXE),$(call findfiles,test,*algebra_solver*_test.cpp))
85+
ALGEBRA_SOLVER_TESTS := $(subst .cpp,$(EXE),$(call findfiles,test,*solve*_test.cpp))
8686
$(ALGEBRA_SOLVER_TESTS) : $(LIBSUNDIALS)
8787

8888
### These can be generated by the jumbo tests and include the above kinds

stan/math/prim/fun/atan2.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <stan/math/prim/core.hpp>
55
#include <stan/math/prim/meta.hpp>
6+
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
67
#include <cmath>
78

89
namespace stan {
@@ -23,6 +24,23 @@ double atan2(T1 y, T2 x) {
2324
return std::atan2(y, x);
2425
}
2526

27+
/**
28+
* Enables the vectorised application of the atan2 function, when
29+
* the first and/or second arguments are containers.
30+
*
31+
* @tparam T1 type of first input
32+
* @tparam T2 type of second input
33+
* @param a First input
34+
* @param b Second input
35+
* @return Returns the atan2 function applied to the two inputs.
36+
*/
37+
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
38+
require_all_not_var_matrix_t<T1, T2>* = nullptr>
39+
inline auto atan2(const T1& a, const T2& b) {
40+
return apply_scalar_binary(
41+
a, b, [](const auto& c, const auto& d) { return atan2(c, d); });
42+
}
43+
2644
} // namespace math
2745
} // namespace stan
2846

stan/math/prim/fun/minus.hpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define STAN_MATH_PRIM_FUN_MINUS_HPP
33

44
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/functor/apply_vector_unary.hpp>
56

67
namespace stan {
78
namespace math {
@@ -13,11 +14,24 @@ namespace math {
1314
* @param x Subtrahend.
1415
* @return Negation of subtrahend.
1516
*/
16-
template <typename T>
17+
template <typename T, require_not_std_vector_t<T>* = nullptr>
1718
inline auto minus(const T& x) {
1819
return -x;
1920
}
2021

22+
/**
23+
* Return the negation of the each element of a vector
24+
*
25+
* @tparam T Type of container.
26+
* @param x Container.
27+
* @return Container where each element is negated.
28+
*/
29+
template <typename T>
30+
inline auto minus(const std::vector<T>& x) {
31+
return apply_vector_unary<std::vector<T>>::apply(
32+
x, [](const auto& v) { return -v; });
33+
}
34+
2135
} // namespace math
2236
} // namespace stan
2337

stan/math/prim/fun/stan_print.hpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/fun/Eigen.hpp>
6+
#include <stan/math/prim/functor/for_each.hpp>
67
#include <vector>
78

89
namespace stan {
910
namespace math {
1011
// prints used in generator for print() statements in modeling language
1112

12-
template <typename T, require_not_container_t<T>* = nullptr>
13+
template <typename T, require_not_container_t<T>* = nullptr,
14+
require_not_tuple_t<T>* = nullptr>
1315
void stan_print(std::ostream* o, const T& x) {
1416
*o << x;
1517
}
@@ -50,8 +52,8 @@ void stan_print(std::ostream* o, const EigMat& x) {
5052
*o << ']';
5153
}
5254

53-
template <typename T>
54-
void stan_print(std::ostream* o, const std::vector<T>& x) {
55+
template <typename T, require_std_vector_t<T>* = nullptr>
56+
void stan_print(std::ostream* o, const T& x) {
5557
*o << '[';
5658
for (size_t i = 0; i < x.size(); ++i) {
5759
if (i > 0) {
@@ -62,6 +64,23 @@ void stan_print(std::ostream* o, const std::vector<T>& x) {
6264
*o << ']';
6365
}
6466

67+
template <typename T, require_tuple_t<T>* = nullptr>
68+
void stan_print(std::ostream* o, const T& x) {
69+
*o << '(';
70+
constexpr auto tuple_size = std::tuple_size<std::decay_t<T>>::value;
71+
size_t i = 0;
72+
stan::math::for_each(
73+
[&i, o](auto&& elt) {
74+
if (i > 0) {
75+
*o << ',';
76+
}
77+
stan_print(o, elt);
78+
i++;
79+
},
80+
x);
81+
*o << ')';
82+
}
83+
6584
} // namespace math
6685
} // namespace stan
6786
#endif

stan/math/rev/fun/atan2.hpp

Lines changed: 142 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ inline var atan2(const var& a, const var& b) {
2929
std::atan2(a.val(), b.val()), [a, b](const auto& vi) mutable {
3030
double a_sq_plus_b_sq = (a.val() * a.val()) + (b.val() * b.val());
3131
a.adj() += vi.adj_ * b.val() / a_sq_plus_b_sq;
32-
b.adj() -= vi.adj_ * a.val() / a_sq_plus_b_sq;
32+
b.adj() += -vi.adj_ * a.val() / a_sq_plus_b_sq;
3333
});
3434
}
3535

@@ -93,10 +93,150 @@ inline var atan2(double a, const var& b) {
9393
return make_callback_var(
9494
std::atan2(a, b.val()), [a, b](const auto& vi) mutable {
9595
double a_sq_plus_b_sq = (a * a) + (b.val() * b.val());
96-
b.adj() -= vi.adj_ * a / a_sq_plus_b_sq;
96+
b.adj() += -vi.adj_ * a / a_sq_plus_b_sq;
9797
});
9898
}
9999

100+
template <typename Mat1, typename Mat2,
101+
require_any_var_matrix_t<Mat1, Mat2>* = nullptr,
102+
require_all_matrix_t<Mat1, Mat2>* = nullptr>
103+
inline auto atan2(const Mat1& a, const Mat2& b) {
104+
if (!is_constant<Mat1>::value && !is_constant<Mat2>::value) {
105+
arena_t<promote_scalar_t<var, Mat1>> arena_a = a;
106+
arena_t<promote_scalar_t<var, Mat2>> arena_b = b;
107+
auto atan2_val = atan2(arena_a.val(), arena_b.val());
108+
auto a_sq_plus_b_sq
109+
= to_arena((arena_a.val().array() * arena_a.val().array())
110+
+ (arena_b.val().array() * arena_b.val().array()));
111+
return make_callback_var(
112+
atan2(arena_a.val(), arena_b.val()),
113+
[arena_a, arena_b, a_sq_plus_b_sq](auto& vi) mutable {
114+
arena_a.adj().array()
115+
+= vi.adj().array() * arena_b.val().array() / a_sq_plus_b_sq;
116+
arena_b.adj().array()
117+
+= -vi.adj().array() * arena_a.val().array() / a_sq_plus_b_sq;
118+
});
119+
} else if (!is_constant<Mat1>::value) {
120+
arena_t<promote_scalar_t<var, Mat1>> arena_a = a;
121+
arena_t<promote_scalar_t<double, Mat2>> arena_b = value_of(b);
122+
auto a_sq_plus_b_sq
123+
= to_arena((arena_a.val().array() * arena_a.val().array())
124+
+ (arena_b.array() * arena_b.array()));
125+
126+
return make_callback_var(
127+
atan2(arena_a.val(), arena_b),
128+
[arena_a, arena_b, a_sq_plus_b_sq](auto& vi) mutable {
129+
arena_a.adj().array()
130+
+= vi.adj().array() * arena_b.array() / a_sq_plus_b_sq;
131+
});
132+
} else if (!is_constant<Mat2>::value) {
133+
arena_t<promote_scalar_t<double, Mat1>> arena_a = value_of(a);
134+
arena_t<promote_scalar_t<var, Mat2>> arena_b = b;
135+
auto a_sq_plus_b_sq
136+
= to_arena((arena_a.array() * arena_a.array())
137+
+ (arena_b.val().array() * arena_b.val().array()));
138+
139+
return make_callback_var(
140+
atan2(arena_a, arena_b.val()),
141+
[arena_a, arena_b, a_sq_plus_b_sq](auto& vi) mutable {
142+
arena_b.adj().array()
143+
+= -vi.adj().array() * arena_a.array() / a_sq_plus_b_sq;
144+
});
145+
}
146+
}
147+
148+
template <typename Scalar, typename VarMat,
149+
require_var_matrix_t<VarMat>* = nullptr,
150+
require_stan_scalar_t<Scalar>* = nullptr>
151+
inline auto atan2(const Scalar& a, const VarMat& b) {
152+
if (!is_constant<Scalar>::value && !is_constant<VarMat>::value) {
153+
var arena_a = a;
154+
arena_t<promote_scalar_t<var, VarMat>> arena_b = b;
155+
auto atan2_val = atan2(arena_a.val(), arena_b.val());
156+
auto a_sq_plus_b_sq
157+
= to_arena((arena_a.val() * arena_a.val())
158+
+ (arena_b.val().array() * arena_b.val().array()));
159+
return make_callback_var(
160+
atan2(arena_a.val(), arena_b.val()),
161+
[arena_a, arena_b, a_sq_plus_b_sq](auto& vi) mutable {
162+
arena_a.adj()
163+
+= (vi.adj().array() * arena_b.val().array() / a_sq_plus_b_sq)
164+
.sum();
165+
arena_b.adj().array()
166+
+= -vi.adj().array() * arena_a.val() / a_sq_plus_b_sq;
167+
});
168+
} else if (!is_constant<Scalar>::value) {
169+
var arena_a = a;
170+
arena_t<promote_scalar_t<double, VarMat>> arena_b = value_of(b);
171+
auto a_sq_plus_b_sq = to_arena((arena_a.val() * arena_a.val())
172+
+ (arena_b.array() * arena_b.array()));
173+
174+
return make_callback_var(
175+
atan2(arena_a.val(), arena_b),
176+
[arena_a, arena_b, a_sq_plus_b_sq](auto& vi) mutable {
177+
arena_a.adj()
178+
+= (vi.adj().array() * arena_b.array() / a_sq_plus_b_sq).sum();
179+
});
180+
} else if (!is_constant<VarMat>::value) {
181+
double arena_a = value_of(a);
182+
arena_t<promote_scalar_t<var, VarMat>> arena_b = b;
183+
auto a_sq_plus_b_sq = to_arena(
184+
(arena_a * arena_a) + (arena_b.val().array() * arena_b.val().array()));
185+
186+
return make_callback_var(
187+
atan2(arena_a, arena_b.val()),
188+
[arena_a, arena_b, a_sq_plus_b_sq](auto& vi) mutable {
189+
arena_b.adj().array() += -vi.adj().array() * arena_a / a_sq_plus_b_sq;
190+
});
191+
}
192+
}
193+
194+
template <typename VarMat, typename Scalar,
195+
require_var_matrix_t<VarMat>* = nullptr,
196+
require_stan_scalar_t<Scalar>* = nullptr>
197+
inline auto atan2(const VarMat& a, const Scalar& b) {
198+
if (!is_constant<VarMat>::value && !is_constant<Scalar>::value) {
199+
arena_t<promote_scalar_t<var, VarMat>> arena_a = a;
200+
var arena_b = b;
201+
auto atan2_val = atan2(arena_a.val(), arena_b.val());
202+
auto a_sq_plus_b_sq
203+
= to_arena((arena_a.val().array() * arena_a.val().array())
204+
+ (arena_b.val() * arena_b.val()));
205+
return make_callback_var(
206+
atan2(arena_a.val(), arena_b.val()),
207+
[arena_a, arena_b, a_sq_plus_b_sq](auto& vi) mutable {
208+
arena_a.adj().array()
209+
+= vi.adj().array() * arena_b.val() / a_sq_plus_b_sq;
210+
arena_b.adj()
211+
+= -(vi.adj().array() * arena_a.val().array() / a_sq_plus_b_sq)
212+
.sum();
213+
});
214+
} else if (!is_constant<VarMat>::value) {
215+
arena_t<promote_scalar_t<var, VarMat>> arena_a = a;
216+
double arena_b = value_of(b);
217+
auto a_sq_plus_b_sq = to_arena(
218+
(arena_a.val().array() * arena_a.val().array()) + (arena_b * arena_b));
219+
220+
return make_callback_var(
221+
atan2(arena_a.val(), arena_b),
222+
[arena_a, arena_b, a_sq_plus_b_sq](auto& vi) mutable {
223+
arena_a.adj().array() += vi.adj().array() * arena_b / a_sq_plus_b_sq;
224+
});
225+
} else if (!is_constant<Scalar>::value) {
226+
arena_t<promote_scalar_t<double, VarMat>> arena_a = value_of(a);
227+
var arena_b = b;
228+
auto a_sq_plus_b_sq = to_arena((arena_a.array() * arena_a.array())
229+
+ (arena_b.val() * arena_b.val()));
230+
231+
return make_callback_var(
232+
atan2(arena_a, arena_b.val()),
233+
[arena_a, arena_b, a_sq_plus_b_sq](auto& vi) mutable {
234+
arena_b.adj()
235+
+= -(vi.adj().array() * arena_a.array() / a_sq_plus_b_sq).sum();
236+
});
237+
}
238+
}
239+
100240
} // namespace math
101241
} // namespace stan
102242
#endif

stan/math/rev/functor.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
#define STAN_MATH_REV_FUNCTOR_HPP
33

44
#include <stan/math/rev/functor/algebra_solver_fp.hpp>
5-
#include <stan/math/rev/functor/algebra_solver_powell.hpp>
6-
#include <stan/math/rev/functor/algebra_solver_newton.hpp>
5+
#include <stan/math/rev/functor/solve_powell.hpp>
6+
#include <stan/math/rev/functor/solve_newton.hpp>
77
#include <stan/math/rev/functor/algebra_system.hpp>
88
#include <stan/math/rev/functor/apply_scalar_unary.hpp>
99
#include <stan/math/rev/functor/apply_scalar_binary.hpp>

0 commit comments

Comments
 (0)