Skip to content

Commit 39a2855

Browse files
authored
Merge pull request #2812 from stan-dev/vectorize-atan2
vectorize atan2
2 parents faffb4a + c9c5e9f commit 39a2855

4 files changed

Lines changed: 232 additions & 2 deletions

File tree

stan/math/prim/fun/atan2.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <stan/math/prim/core.hpp>
55
#include <stan/math/prim/meta.hpp>
6+
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
67
#include <cmath>
78

89
namespace stan {
@@ -23,6 +24,23 @@ double atan2(T1 y, T2 x) {
2324
return std::atan2(y, x);
2425
}
2526

27+
/**
28+
* Enables the vectorised application of the atan2 function, when
29+
* the first and/or second arguments are containers.
30+
*
31+
* @tparam T1 type of first input
32+
* @tparam T2 type of second input
33+
* @param a First input
34+
* @param b Second input
35+
* @return Returns the atan2 function applied to the two inputs.
36+
*/
37+
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
38+
require_all_not_var_matrix_t<T1, T2>* = nullptr>
39+
inline auto atan2(const T1& a, const T2& b) {
40+
return apply_scalar_binary(
41+
a, b, [](const auto& c, const auto& d) { return atan2(c, d); });
42+
}
43+
2644
} // namespace math
2745
} // namespace stan
2846

stan/math/rev/fun/atan2.hpp

Lines changed: 142 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ inline var atan2(const var& a, const var& b) {
2929
std::atan2(a.val(), b.val()), [a, b](const auto& vi) mutable {
3030
double a_sq_plus_b_sq = (a.val() * a.val()) + (b.val() * b.val());
3131
a.adj() += vi.adj_ * b.val() / a_sq_plus_b_sq;
32-
b.adj() -= vi.adj_ * a.val() / a_sq_plus_b_sq;
32+
b.adj() += -vi.adj_ * a.val() / a_sq_plus_b_sq;
3333
});
3434
}
3535

@@ -93,10 +93,150 @@ inline var atan2(double a, const var& b) {
9393
return make_callback_var(
9494
std::atan2(a, b.val()), [a, b](const auto& vi) mutable {
9595
double a_sq_plus_b_sq = (a * a) + (b.val() * b.val());
96-
b.adj() -= vi.adj_ * a / a_sq_plus_b_sq;
96+
b.adj() += -vi.adj_ * a / a_sq_plus_b_sq;
9797
});
9898
}
9999

100+
template <typename Mat1, typename Mat2,
101+
require_any_var_matrix_t<Mat1, Mat2>* = nullptr,
102+
require_all_matrix_t<Mat1, Mat2>* = nullptr>
103+
inline auto atan2(const Mat1& a, const Mat2& b) {
104+
if (!is_constant<Mat1>::value && !is_constant<Mat2>::value) {
105+
arena_t<promote_scalar_t<var, Mat1>> arena_a = a;
106+
arena_t<promote_scalar_t<var, Mat2>> arena_b = b;
107+
auto atan2_val = atan2(arena_a.val(), arena_b.val());
108+
auto a_sq_plus_b_sq
109+
= to_arena((arena_a.val().array() * arena_a.val().array())
110+
+ (arena_b.val().array() * arena_b.val().array()));
111+
return make_callback_var(
112+
atan2(arena_a.val(), arena_b.val()),
113+
[arena_a, arena_b, a_sq_plus_b_sq](auto& vi) mutable {
114+
arena_a.adj().array()
115+
+= vi.adj().array() * arena_b.val().array() / a_sq_plus_b_sq;
116+
arena_b.adj().array()
117+
+= -vi.adj().array() * arena_a.val().array() / a_sq_plus_b_sq;
118+
});
119+
} else if (!is_constant<Mat1>::value) {
120+
arena_t<promote_scalar_t<var, Mat1>> arena_a = a;
121+
arena_t<promote_scalar_t<double, Mat2>> arena_b = value_of(b);
122+
auto a_sq_plus_b_sq
123+
= to_arena((arena_a.val().array() * arena_a.val().array())
124+
+ (arena_b.array() * arena_b.array()));
125+
126+
return make_callback_var(
127+
atan2(arena_a.val(), arena_b),
128+
[arena_a, arena_b, a_sq_plus_b_sq](auto& vi) mutable {
129+
arena_a.adj().array()
130+
+= vi.adj().array() * arena_b.array() / a_sq_plus_b_sq;
131+
});
132+
} else if (!is_constant<Mat2>::value) {
133+
arena_t<promote_scalar_t<double, Mat1>> arena_a = value_of(a);
134+
arena_t<promote_scalar_t<var, Mat2>> arena_b = b;
135+
auto a_sq_plus_b_sq
136+
= to_arena((arena_a.array() * arena_a.array())
137+
+ (arena_b.val().array() * arena_b.val().array()));
138+
139+
return make_callback_var(
140+
atan2(arena_a, arena_b.val()),
141+
[arena_a, arena_b, a_sq_plus_b_sq](auto& vi) mutable {
142+
arena_b.adj().array()
143+
+= -vi.adj().array() * arena_a.array() / a_sq_plus_b_sq;
144+
});
145+
}
146+
}
147+
148+
template <typename Scalar, typename VarMat,
149+
require_var_matrix_t<VarMat>* = nullptr,
150+
require_stan_scalar_t<Scalar>* = nullptr>
151+
inline auto atan2(const Scalar& a, const VarMat& b) {
152+
if (!is_constant<Scalar>::value && !is_constant<VarMat>::value) {
153+
var arena_a = a;
154+
arena_t<promote_scalar_t<var, VarMat>> arena_b = b;
155+
auto atan2_val = atan2(arena_a.val(), arena_b.val());
156+
auto a_sq_plus_b_sq
157+
= to_arena((arena_a.val() * arena_a.val())
158+
+ (arena_b.val().array() * arena_b.val().array()));
159+
return make_callback_var(
160+
atan2(arena_a.val(), arena_b.val()),
161+
[arena_a, arena_b, a_sq_plus_b_sq](auto& vi) mutable {
162+
arena_a.adj()
163+
+= (vi.adj().array() * arena_b.val().array() / a_sq_plus_b_sq)
164+
.sum();
165+
arena_b.adj().array()
166+
+= -vi.adj().array() * arena_a.val() / a_sq_plus_b_sq;
167+
});
168+
} else if (!is_constant<Scalar>::value) {
169+
var arena_a = a;
170+
arena_t<promote_scalar_t<double, VarMat>> arena_b = value_of(b);
171+
auto a_sq_plus_b_sq = to_arena((arena_a.val() * arena_a.val())
172+
+ (arena_b.array() * arena_b.array()));
173+
174+
return make_callback_var(
175+
atan2(arena_a.val(), arena_b),
176+
[arena_a, arena_b, a_sq_plus_b_sq](auto& vi) mutable {
177+
arena_a.adj()
178+
+= (vi.adj().array() * arena_b.array() / a_sq_plus_b_sq).sum();
179+
});
180+
} else if (!is_constant<VarMat>::value) {
181+
double arena_a = value_of(a);
182+
arena_t<promote_scalar_t<var, VarMat>> arena_b = b;
183+
auto a_sq_plus_b_sq = to_arena(
184+
(arena_a * arena_a) + (arena_b.val().array() * arena_b.val().array()));
185+
186+
return make_callback_var(
187+
atan2(arena_a, arena_b.val()),
188+
[arena_a, arena_b, a_sq_plus_b_sq](auto& vi) mutable {
189+
arena_b.adj().array() += -vi.adj().array() * arena_a / a_sq_plus_b_sq;
190+
});
191+
}
192+
}
193+
194+
template <typename VarMat, typename Scalar,
195+
require_var_matrix_t<VarMat>* = nullptr,
196+
require_stan_scalar_t<Scalar>* = nullptr>
197+
inline auto atan2(const VarMat& a, const Scalar& b) {
198+
if (!is_constant<VarMat>::value && !is_constant<Scalar>::value) {
199+
arena_t<promote_scalar_t<var, VarMat>> arena_a = a;
200+
var arena_b = b;
201+
auto atan2_val = atan2(arena_a.val(), arena_b.val());
202+
auto a_sq_plus_b_sq
203+
= to_arena((arena_a.val().array() * arena_a.val().array())
204+
+ (arena_b.val() * arena_b.val()));
205+
return make_callback_var(
206+
atan2(arena_a.val(), arena_b.val()),
207+
[arena_a, arena_b, a_sq_plus_b_sq](auto& vi) mutable {
208+
arena_a.adj().array()
209+
+= vi.adj().array() * arena_b.val() / a_sq_plus_b_sq;
210+
arena_b.adj()
211+
+= -(vi.adj().array() * arena_a.val().array() / a_sq_plus_b_sq)
212+
.sum();
213+
});
214+
} else if (!is_constant<VarMat>::value) {
215+
arena_t<promote_scalar_t<var, VarMat>> arena_a = a;
216+
double arena_b = value_of(b);
217+
auto a_sq_plus_b_sq = to_arena(
218+
(arena_a.val().array() * arena_a.val().array()) + (arena_b * arena_b));
219+
220+
return make_callback_var(
221+
atan2(arena_a.val(), arena_b),
222+
[arena_a, arena_b, a_sq_plus_b_sq](auto& vi) mutable {
223+
arena_a.adj().array() += vi.adj().array() * arena_b / a_sq_plus_b_sq;
224+
});
225+
} else if (!is_constant<Scalar>::value) {
226+
arena_t<promote_scalar_t<double, VarMat>> arena_a = value_of(a);
227+
var arena_b = b;
228+
auto a_sq_plus_b_sq = to_arena((arena_a.array() * arena_a.array())
229+
+ (arena_b.val() * arena_b.val()));
230+
231+
return make_callback_var(
232+
atan2(arena_a, arena_b.val()),
233+
[arena_a, arena_b, a_sq_plus_b_sq](auto& vi) mutable {
234+
arena_b.adj()
235+
+= -(vi.adj().array() * arena_a.array() / a_sq_plus_b_sq).sum();
236+
});
237+
}
238+
}
239+
100240
} // namespace math
101241
} // namespace stan
102242
#endif

test/unit/math/mix/fun/atan2_test.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <test/unit/math/test_ad.hpp>
22
#include <limits>
3+
#include <vector>
34

45
TEST(mathMixCore, atan2) {
56
auto f = [](const auto& x1, const auto& x2) {
@@ -15,3 +16,40 @@ TEST(mathMixCore, atan2) {
1516
stan::test::expect_ad(f, 1.2, 3.9);
1617
stan::test::expect_ad(f, 0.5, 2.3);
1718
}
19+
20+
TEST(mathMixScalFun, atan2) {
21+
auto f = [](const auto& x1, const auto& x2) {
22+
using stan::math::atan2;
23+
return atan2(x1, x2);
24+
};
25+
26+
// finite differences fails for
27+
// infinite inputs
28+
// stan::test::expect_common_nonzero_binary(f);
29+
30+
stan::test::expect_ad(f, 1.0, 1.0);
31+
stan::test::expect_ad(f, 1.0, 0.5);
32+
stan::test::expect_ad(f, 1.2, 3.9);
33+
stan::test::expect_ad(f, 7.5, 1.8);
34+
35+
Eigen::VectorXd in1(3);
36+
in1 << 1.0, 1.0, 1.2;
37+
Eigen::VectorXd in2(3);
38+
in2 << 1.0, 0.5, 3.9;
39+
stan::test::expect_ad_vectorized_binary(f, in1, in2);
40+
}
41+
42+
TEST(mathMixScalFun, atan2_varmat) {
43+
auto f = [](const auto& x1, const auto& x2) {
44+
using stan::math::atan2;
45+
return atan2(x1, x2);
46+
};
47+
48+
Eigen::VectorXd in1(3);
49+
in1 << 0.5, 3.4, 5.2;
50+
Eigen::VectorXd in2(3);
51+
in2 << 3.3, 0.9, 6.7;
52+
stan::test::expect_ad_matvar(f, in1, in2);
53+
stan::test::expect_ad_matvar(f, in1(0), in2);
54+
stan::test::expect_ad_matvar(f, in1, in2(0));
55+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#include <stan/math/prim.hpp>
2+
#include <test/unit/math/prim/fun/binary_scalar_tester.hpp>
3+
#include <gtest/gtest.h>
4+
#include <cmath>
5+
6+
TEST(MathFunctions, atan2) {
7+
using stan::math::atan2;
8+
9+
EXPECT_FLOAT_EQ(atan2(2.15, 1.71), 0.8988979010770248000345);
10+
EXPECT_FLOAT_EQ(atan2(7.62, 10.15), 0.6439738474911284019668);
11+
}
12+
13+
TEST(MathFunctions, atan2_nan) {
14+
using stan::math::atan2;
15+
using stan::math::INFTY;
16+
using stan::math::NOT_A_NUMBER;
17+
18+
EXPECT_TRUE(std::isnan(atan2(NOT_A_NUMBER, 2.16)));
19+
EXPECT_TRUE(std::isnan(atan2(1.65, NOT_A_NUMBER)));
20+
21+
EXPECT_FALSE(std::isnan(atan2(INFTY, 2.16)));
22+
EXPECT_FALSE(std::isnan(atan2(1.65, INFTY)));
23+
}
24+
25+
TEST(MathFunctions, atan2_vec) {
26+
auto f = [](const auto& x1, const auto& x2) {
27+
return stan::math::atan2(x1, x2);
28+
};
29+
30+
Eigen::VectorXd in1 = Eigen::VectorXd::Random(6);
31+
Eigen::VectorXd in2 = Eigen::VectorXd::Random(6);
32+
33+
stan::test::binary_scalar_tester(f, in1, in2);
34+
}

0 commit comments

Comments
 (0)