55#include < stan/math/rev/core.hpp>
66#include < stan/math/prim/fun/constants.hpp>
77#include < stan/math/prim/fun/erf.hpp>
8+ #include < stan/math/prim/fun/eval.hpp>
89#include < stan/math/prim/fun/owens_t.hpp>
910#include < stan/math/prim/fun/square.hpp>
1011#include < cmath>
1112
1213namespace stan {
1314namespace math {
1415
15- namespace internal {
16- class owens_t_vv_vari : public op_vv_vari {
17- public:
18- owens_t_vv_vari (vari* avi, vari* bvi)
19- : op_vv_vari(owens_t (avi->val_, bvi->val_), avi, bvi) {}
20- void chain () {
21- const double neg_avi_sq_div_2 = -square (avi_->val_ ) * 0.5 ;
22- const double one_p_bvi_sq = 1.0 + square (bvi_->val_ );
23-
24- avi_->adj_ += adj_ * erf (bvi_->val_ * avi_->val_ * INV_SQRT_TWO)
25- * std::exp (neg_avi_sq_div_2) * INV_SQRT_TWO_PI * -0.5 ;
26- bvi_->adj_ += adj_ * std::exp (neg_avi_sq_div_2 * one_p_bvi_sq)
27- / (one_p_bvi_sq * TWO_PI);
28- }
29- };
30-
31- class owens_t_vd_vari : public op_vd_vari {
32- public:
33- owens_t_vd_vari (vari* avi, double b)
34- : op_vd_vari(owens_t (avi->val_, b), avi, b) {}
35- void chain () {
36- avi_->adj_ += adj_ * erf (bd_ * avi_->val_ * INV_SQRT_TWO)
37- * std::exp (-square (avi_->val_ ) * 0.5 ) * INV_SQRT_TWO_PI
38- * -0.5 ;
39- }
40- };
41-
42- class owens_t_dv_vari : public op_dv_vari {
43- public:
44- owens_t_dv_vari (double a, vari* bvi)
45- : op_dv_vari(owens_t (a, bvi->val_), a, bvi) {}
46- void chain () {
47- const double one_p_bvi_sq = 1.0 + square (bvi_->val_ );
48- bvi_->adj_ += adj_ * std::exp (-0.5 * square (ad_) * one_p_bvi_sq)
49- / (one_p_bvi_sq * TWO_PI);
50- }
51- };
52- } // namespace internal
53-
5416/* *
5517 * The Owen's T function of h and a.
5618 *
5719 * Used to compute the cumulative density function for the skew normal
5820 * distribution.
5921 *
22+ * @tparam Var1 A scalar or Eigen type whose `scalar_type` is an var.
23+ * @tparam Var2 A scalar or Eigen type whose `scalar_type` is an var.
6024 * @param h var parameter.
6125 * @param a var parameter.
6226 * @return The Owen's T function.
6327 */
64- inline var owens_t (const var& h, const var& a) {
65- return var (new internal::owens_t_vv_vari (h.vi_ , a.vi_ ));
28+ template <typename Var1, typename Var2,
29+ require_all_st_var<Var1, Var2>* = nullptr ,
30+ require_all_not_std_vector_t <Var1, Var2>* = nullptr >
31+ inline auto owens_t (const Var1& h, const Var2& a) {
32+ auto h_arena = to_arena (h);
33+ auto a_arena = to_arena (a);
34+ using return_type
35+ = return_var_matrix_t <decltype (owens_t (h_arena.val (), a_arena.val ())),
36+ Var1, Var2>;
37+ arena_t <return_type> ret = owens_t (h_arena.val (), a_arena.val ());
38+ reverse_pass_callback ([h_arena, a_arena, ret]() mutable {
39+ const auto & h_val = as_value_array_or_scalar (h_arena);
40+ const auto & a_val = as_value_array_or_scalar (a_arena);
41+ const auto neg_h_sq_div_2 = stan::math::eval (-square (h_val) * 0.5 );
42+ const auto one_p_a_sq = stan::math::eval (1.0 + square (a_val));
43+ as_array_or_scalar (h_arena).adj () += possibly_sum<is_stan_scalar<Var1>>(
44+ as_array_or_scalar (ret.adj ()) * erf (a_val * h_val * INV_SQRT_TWO)
45+ * exp (neg_h_sq_div_2) * INV_SQRT_TWO_PI * -0.5 );
46+ as_array_or_scalar (a_arena).adj () += possibly_sum<is_stan_scalar<Var2>>(
47+ as_array_or_scalar (ret.adj ()) * exp (neg_h_sq_div_2 * one_p_a_sq)
48+ / (one_p_a_sq * TWO_PI));
49+ });
50+ return return_type (ret);
6651}
6752
6853/* *
@@ -71,12 +56,30 @@ inline var owens_t(const var& h, const var& a) {
7156 * Used to compute the cumulative density function for the skew normal
7257 * distribution.
7358 *
59+ * @tparam Var A scalar or Eigen type whose `scalar_type` is an var.
60+ * @tparam Arith A scalar or Eigen type with an inner arirthmetic scalar value.
7461 * @param h var parameter.
7562 * @param a double parameter.
7663 * @return The Owen's T function.
7764 */
78- inline var owens_t (const var& h, double a) {
79- return var (new internal::owens_t_vd_vari (h.vi_ , a));
65+ template <typename Var, typename Arith, require_st_arithmetic<Arith>* = nullptr ,
66+ require_all_not_std_vector_t <Var, Arith>* = nullptr ,
67+ require_st_var<Var>* = nullptr >
68+ inline auto owens_t (const Var& h, const Arith& a) {
69+ auto h_arena = to_arena (h);
70+ auto a_arena = to_arena (a);
71+ using return_type
72+ = return_var_matrix_t <decltype (owens_t (h_arena.val (), a_arena)), Var,
73+ Arith>;
74+ arena_t <return_type> ret = owens_t (h_arena.val (), a_arena);
75+ reverse_pass_callback ([h_arena, a_arena, ret]() mutable {
76+ const auto & h_val = as_value_array_or_scalar (h_arena);
77+ as_array_or_scalar (h_arena).adj () += possibly_sum<is_stan_scalar<Var>>(
78+ as_array_or_scalar (ret.adj ())
79+ * erf (as_array_or_scalar (a_arena) * h_val * INV_SQRT_TWO)
80+ * exp (-square (h_val) * 0.5 ) * INV_SQRT_TWO_PI * -0.5 );
81+ });
82+ return return_type (ret);
8083}
8184
8285/* *
@@ -85,12 +88,31 @@ inline var owens_t(const var& h, double a) {
8588 * Used to compute the cumulative density function for the skew normal
8689 * distribution.
8790 *
91+ * @tparam Var A scalar or Eigen type whose `scalar_type` is an var.
92+ * @tparam Arith A scalar or Eigen type with an inner arithmetic scalar value.
8893 * @param h double parameter.
8994 * @param a var parameter.
9095 * @return The Owen's T function.
9196 */
92- inline var owens_t (double h, const var& a) {
93- return var (new internal::owens_t_dv_vari (h, a.vi_ ));
97+ template <typename Arith, typename Var, require_st_arithmetic<Arith>* = nullptr ,
98+ require_all_not_std_vector_t <Var, Arith>* = nullptr ,
99+ require_st_var<Var>* = nullptr >
100+ inline auto owens_t (const Arith& h, const Var& a) {
101+ auto h_arena = to_arena (h);
102+ auto a_arena = to_arena (a);
103+ using return_type
104+ = return_var_matrix_t <decltype (owens_t (h_arena, a_arena.val ())), Var,
105+ Arith>;
106+ arena_t <return_type> ret = owens_t (h_arena, a_arena.val ());
107+ reverse_pass_callback ([h_arena, a_arena, ret]() mutable {
108+ const auto one_p_a_sq
109+ = eval (1.0 + square (as_value_array_or_scalar (a_arena)));
110+ as_array_or_scalar (a_arena).adj () += possibly_sum<is_stan_scalar<Var>>(
111+ as_array_or_scalar (ret.adj ())
112+ * exp (-0.5 * square (as_array_or_scalar (h_arena)) * one_p_a_sq)
113+ / (one_p_a_sq * TWO_PI));
114+ });
115+ return return_type (ret);
94116}
95117
96118} // namespace math
0 commit comments