@@ -29,7 +29,7 @@ inline var atan2(const var& a, const var& b) {
2929 std::atan2 (a.val (), b.val ()), [a, b](const auto & vi) mutable {
3030 double a_sq_plus_b_sq = (a.val () * a.val ()) + (b.val () * b.val ());
3131 a.adj () += vi.adj_ * b.val () / a_sq_plus_b_sq;
32- b.adj () -= vi.adj_ * a.val () / a_sq_plus_b_sq;
32+ b.adj () += - vi.adj_ * a.val () / a_sq_plus_b_sq;
3333 });
3434}
3535
@@ -93,10 +93,150 @@ inline var atan2(double a, const var& b) {
9393 return make_callback_var (
9494 std::atan2 (a, b.val ()), [a, b](const auto & vi) mutable {
9595 double a_sq_plus_b_sq = (a * a) + (b.val () * b.val ());
96- b.adj () -= vi.adj_ * a / a_sq_plus_b_sq;
96+ b.adj () += - vi.adj_ * a / a_sq_plus_b_sq;
9797 });
9898}
9999
100+ template <typename Mat1, typename Mat2,
101+ require_any_var_matrix_t <Mat1, Mat2>* = nullptr ,
102+ require_all_matrix_t <Mat1, Mat2>* = nullptr >
103+ inline auto atan2 (const Mat1& a, const Mat2& b) {
104+ if (!is_constant<Mat1>::value && !is_constant<Mat2>::value) {
105+ arena_t <promote_scalar_t <var, Mat1>> arena_a = a;
106+ arena_t <promote_scalar_t <var, Mat2>> arena_b = b;
107+ auto atan2_val = atan2 (arena_a.val (), arena_b.val ());
108+ auto a_sq_plus_b_sq
109+ = to_arena ((arena_a.val ().array () * arena_a.val ().array ())
110+ + (arena_b.val ().array () * arena_b.val ().array ()));
111+ return make_callback_var (
112+ atan2 (arena_a.val (), arena_b.val ()),
113+ [arena_a, arena_b, a_sq_plus_b_sq](auto & vi) mutable {
114+ arena_a.adj ().array ()
115+ += vi.adj ().array () * arena_b.val ().array () / a_sq_plus_b_sq;
116+ arena_b.adj ().array ()
117+ += -vi.adj ().array () * arena_a.val ().array () / a_sq_plus_b_sq;
118+ });
119+ } else if (!is_constant<Mat1>::value) {
120+ arena_t <promote_scalar_t <var, Mat1>> arena_a = a;
121+ arena_t <promote_scalar_t <double , Mat2>> arena_b = value_of (b);
122+ auto a_sq_plus_b_sq
123+ = to_arena ((arena_a.val ().array () * arena_a.val ().array ())
124+ + (arena_b.array () * arena_b.array ()));
125+
126+ return make_callback_var (
127+ atan2 (arena_a.val (), arena_b),
128+ [arena_a, arena_b, a_sq_plus_b_sq](auto & vi) mutable {
129+ arena_a.adj ().array ()
130+ += vi.adj ().array () * arena_b.array () / a_sq_plus_b_sq;
131+ });
132+ } else if (!is_constant<Mat2>::value) {
133+ arena_t <promote_scalar_t <double , Mat1>> arena_a = value_of (a);
134+ arena_t <promote_scalar_t <var, Mat2>> arena_b = b;
135+ auto a_sq_plus_b_sq
136+ = to_arena ((arena_a.array () * arena_a.array ())
137+ + (arena_b.val ().array () * arena_b.val ().array ()));
138+
139+ return make_callback_var (
140+ atan2 (arena_a, arena_b.val ()),
141+ [arena_a, arena_b, a_sq_plus_b_sq](auto & vi) mutable {
142+ arena_b.adj ().array ()
143+ += -vi.adj ().array () * arena_a.array () / a_sq_plus_b_sq;
144+ });
145+ }
146+ }
147+
148+ template <typename Scalar, typename VarMat,
149+ require_var_matrix_t <VarMat>* = nullptr ,
150+ require_stan_scalar_t <Scalar>* = nullptr >
151+ inline auto atan2 (const Scalar& a, const VarMat& b) {
152+ if (!is_constant<Scalar>::value && !is_constant<VarMat>::value) {
153+ var arena_a = a;
154+ arena_t <promote_scalar_t <var, VarMat>> arena_b = b;
155+ auto atan2_val = atan2 (arena_a.val (), arena_b.val ());
156+ auto a_sq_plus_b_sq
157+ = to_arena ((arena_a.val () * arena_a.val ())
158+ + (arena_b.val ().array () * arena_b.val ().array ()));
159+ return make_callback_var (
160+ atan2 (arena_a.val (), arena_b.val ()),
161+ [arena_a, arena_b, a_sq_plus_b_sq](auto & vi) mutable {
162+ arena_a.adj ()
163+ += (vi.adj ().array () * arena_b.val ().array () / a_sq_plus_b_sq)
164+ .sum ();
165+ arena_b.adj ().array ()
166+ += -vi.adj ().array () * arena_a.val () / a_sq_plus_b_sq;
167+ });
168+ } else if (!is_constant<Scalar>::value) {
169+ var arena_a = a;
170+ arena_t <promote_scalar_t <double , VarMat>> arena_b = value_of (b);
171+ auto a_sq_plus_b_sq = to_arena ((arena_a.val () * arena_a.val ())
172+ + (arena_b.array () * arena_b.array ()));
173+
174+ return make_callback_var (
175+ atan2 (arena_a.val (), arena_b),
176+ [arena_a, arena_b, a_sq_plus_b_sq](auto & vi) mutable {
177+ arena_a.adj ()
178+ += (vi.adj ().array () * arena_b.array () / a_sq_plus_b_sq).sum ();
179+ });
180+ } else if (!is_constant<VarMat>::value) {
181+ double arena_a = value_of (a);
182+ arena_t <promote_scalar_t <var, VarMat>> arena_b = b;
183+ auto a_sq_plus_b_sq = to_arena (
184+ (arena_a * arena_a) + (arena_b.val ().array () * arena_b.val ().array ()));
185+
186+ return make_callback_var (
187+ atan2 (arena_a, arena_b.val ()),
188+ [arena_a, arena_b, a_sq_plus_b_sq](auto & vi) mutable {
189+ arena_b.adj ().array () += -vi.adj ().array () * arena_a / a_sq_plus_b_sq;
190+ });
191+ }
192+ }
193+
194+ template <typename VarMat, typename Scalar,
195+ require_var_matrix_t <VarMat>* = nullptr ,
196+ require_stan_scalar_t <Scalar>* = nullptr >
197+ inline auto atan2 (const VarMat& a, const Scalar& b) {
198+ if (!is_constant<VarMat>::value && !is_constant<Scalar>::value) {
199+ arena_t <promote_scalar_t <var, VarMat>> arena_a = a;
200+ var arena_b = b;
201+ auto atan2_val = atan2 (arena_a.val (), arena_b.val ());
202+ auto a_sq_plus_b_sq
203+ = to_arena ((arena_a.val ().array () * arena_a.val ().array ())
204+ + (arena_b.val () * arena_b.val ()));
205+ return make_callback_var (
206+ atan2 (arena_a.val (), arena_b.val ()),
207+ [arena_a, arena_b, a_sq_plus_b_sq](auto & vi) mutable {
208+ arena_a.adj ().array ()
209+ += vi.adj ().array () * arena_b.val () / a_sq_plus_b_sq;
210+ arena_b.adj ()
211+ += -(vi.adj ().array () * arena_a.val ().array () / a_sq_plus_b_sq)
212+ .sum ();
213+ });
214+ } else if (!is_constant<VarMat>::value) {
215+ arena_t <promote_scalar_t <var, VarMat>> arena_a = a;
216+ double arena_b = value_of (b);
217+ auto a_sq_plus_b_sq = to_arena (
218+ (arena_a.val ().array () * arena_a.val ().array ()) + (arena_b * arena_b));
219+
220+ return make_callback_var (
221+ atan2 (arena_a.val (), arena_b),
222+ [arena_a, arena_b, a_sq_plus_b_sq](auto & vi) mutable {
223+ arena_a.adj ().array () += vi.adj ().array () * arena_b / a_sq_plus_b_sq;
224+ });
225+ } else if (!is_constant<Scalar>::value) {
226+ arena_t <promote_scalar_t <double , VarMat>> arena_a = value_of (a);
227+ var arena_b = b;
228+ auto a_sq_plus_b_sq = to_arena ((arena_a.array () * arena_a.array ())
229+ + (arena_b.val () * arena_b.val ()));
230+
231+ return make_callback_var (
232+ atan2 (arena_a, arena_b.val ()),
233+ [arena_a, arena_b, a_sq_plus_b_sq](auto & vi) mutable {
234+ arena_b.adj ()
235+ += -(vi.adj ().array () * arena_a.array () / a_sq_plus_b_sq).sum ();
236+ });
237+ }
238+ }
239+
100240} // namespace math
101241} // namespace stan
102242#endif
0 commit comments