Skip to content

Commit 231cbd0

Browse files
authored
Merge pull request #2787 from stan-dev/feature/owens-t-varmat
Allow varmat for owens_t
2 parents d3b129e + a9a9c23 commit 231cbd0

6 files changed

Lines changed: 169 additions & 47 deletions

File tree

stan/math/prim/fun/owens_t.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,11 @@ inline double owens_t(double h, double a) { return boost::math::owens_t(h, a); }
6767
* @param b Second input
6868
* @return owens_t function applied to the two inputs.
6969
*/
70-
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
70+
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
71+
require_all_not_var_and_matrix_types<T1, T2>* = nullptr>
7172
inline auto owens_t(const T1& a, const T2& b) {
7273
return apply_scalar_binary(
73-
a, b, [&](const auto& c, const auto& d) { return owens_t(c, d); });
74+
a, b, [](const auto& c, const auto& d) { return owens_t(c, d); });
7475
}
7576

7677
} // namespace math

stan/math/prim/meta.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@
207207
#include <stan/math/prim/meta/is_string_convertible.hpp>
208208
#include <stan/math/prim/meta/is_tuple.hpp>
209209
#include <stan/math/prim/meta/is_var.hpp>
210+
#include <stan/math/prim/meta/is_var_and_matrix_types.hpp>
210211
#include <stan/math/prim/meta/is_var_matrix.hpp>
211212
#include <stan/math/prim/meta/is_var_dense_dynamic.hpp>
212213
#include <stan/math/prim/meta/is_var_eigen.hpp>
@@ -220,6 +221,7 @@
220221
#include <stan/math/prim/meta/partials_return_type.hpp>
221222
#include <stan/math/prim/meta/partials_type.hpp>
222223
#include <stan/math/prim/meta/plain_type.hpp>
224+
#include <stan/math/prim/meta/possibly_sum.hpp>
223225
#include <stan/math/prim/meta/promote_args.hpp>
224226
#include <stan/math/prim/meta/promote_scalar_type.hpp>
225227
#include <stan/math/prim/meta/ref_type.hpp>
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#ifndef STAN_MATH_PRIM_META_IS_VAR_AND_MATRIX_TYPES_HPP
2+
#define STAN_MATH_PRIM_META_IS_VAR_AND_MATRIX_TYPES_HPP
3+
4+
#include <stan/math/prim/meta/disjunction.hpp>
5+
#include <stan/math/prim/meta/is_var.hpp>
6+
#include <stan/math/prim/meta/is_matrix.hpp>
7+
#include <stan/math/prim/meta/require_helpers.hpp>
8+
#include <stan/math/prim/meta/return_type.hpp>
9+
10+
namespace stan {
11+
12+
/** \ingroup type_trait
13+
* Extends std::true_type when instantiated with at least one type that has a
14+
* var `scalar_type` and at least one type is a matrix. Extends std::false_type
15+
* otherwise.
16+
* @tparam Types Types to test
17+
*/
18+
template <typename... Types>
19+
using is_var_and_matrix_types
20+
= bool_constant<is_var<return_type_t<Types...>>::value
21+
&& stan::math::disjunction<is_matrix<Types>...>::value>;
22+
23+
template <typename... Types>
24+
using require_all_var_and_matrix_types
25+
= require_t<is_var_and_matrix_types<Types...>>;
26+
27+
template <typename... Types>
28+
using require_all_not_var_and_matrix_types
29+
= require_not_t<is_var_and_matrix_types<Types...>>;
30+
31+
} // namespace stan
32+
#endif
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#ifndef STAN_MATH_PRIM_META_POSSIBLY_SUM_HPP
2+
#define STAN_MATH_PRIM_META_POSSIBLY_SUM_HPP
3+
4+
#include <stan/math/prim/fun/sum.hpp>
5+
#include <stan/math/prim/meta/require_helpers.hpp>
6+
7+
namespace stan {
8+
namespace math {
9+
10+
/**
11+
* Conditionally sum the input at compile time.
12+
* @tparam CondSum A struct with a static boolean member `value` which if true
13+
* will allow the input value to be summed
14+
* @tparam T A scalar, Eigen type, or standard vector with inner scalar type.
15+
* @param x The value to be summed.
16+
*/
17+
template <typename CondSum, typename T, require_t<CondSum>* = nullptr>
18+
inline auto possibly_sum(T&& x) {
19+
return sum(std::forward<T>(x));
20+
}
21+
22+
/**
23+
* Conditionally sum the input at compile time. This overload does not sum.
24+
* @tparam CondSum A struct with a static boolean member `value` which if false
25+
* will pass the input to the output.
26+
* @tparam T A scalar, Eigen type, or standard vector with inner scalar type.
27+
* @param x The value to be passed trhough.
28+
*/
29+
template <typename CondSum, typename T1, require_not_t<CondSum>* = nullptr>
30+
inline auto possibly_sum(T1&& x) {
31+
return std::forward<T1>(x);
32+
}
33+
34+
} // namespace math
35+
} // namespace stan
36+
37+
#endif

stan/math/rev/fun/owens_t.hpp

Lines changed: 67 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -5,64 +5,49 @@
55
#include <stan/math/rev/core.hpp>
66
#include <stan/math/prim/fun/constants.hpp>
77
#include <stan/math/prim/fun/erf.hpp>
8+
#include <stan/math/prim/fun/eval.hpp>
89
#include <stan/math/prim/fun/owens_t.hpp>
910
#include <stan/math/prim/fun/square.hpp>
1011
#include <cmath>
1112

1213
namespace stan {
1314
namespace math {
1415

15-
namespace internal {
16-
class owens_t_vv_vari : public op_vv_vari {
17-
public:
18-
owens_t_vv_vari(vari* avi, vari* bvi)
19-
: op_vv_vari(owens_t(avi->val_, bvi->val_), avi, bvi) {}
20-
void chain() {
21-
const double neg_avi_sq_div_2 = -square(avi_->val_) * 0.5;
22-
const double one_p_bvi_sq = 1.0 + square(bvi_->val_);
23-
24-
avi_->adj_ += adj_ * erf(bvi_->val_ * avi_->val_ * INV_SQRT_TWO)
25-
* std::exp(neg_avi_sq_div_2) * INV_SQRT_TWO_PI * -0.5;
26-
bvi_->adj_ += adj_ * std::exp(neg_avi_sq_div_2 * one_p_bvi_sq)
27-
/ (one_p_bvi_sq * TWO_PI);
28-
}
29-
};
30-
31-
class owens_t_vd_vari : public op_vd_vari {
32-
public:
33-
owens_t_vd_vari(vari* avi, double b)
34-
: op_vd_vari(owens_t(avi->val_, b), avi, b) {}
35-
void chain() {
36-
avi_->adj_ += adj_ * erf(bd_ * avi_->val_ * INV_SQRT_TWO)
37-
* std::exp(-square(avi_->val_) * 0.5) * INV_SQRT_TWO_PI
38-
* -0.5;
39-
}
40-
};
41-
42-
class owens_t_dv_vari : public op_dv_vari {
43-
public:
44-
owens_t_dv_vari(double a, vari* bvi)
45-
: op_dv_vari(owens_t(a, bvi->val_), a, bvi) {}
46-
void chain() {
47-
const double one_p_bvi_sq = 1.0 + square(bvi_->val_);
48-
bvi_->adj_ += adj_ * std::exp(-0.5 * square(ad_) * one_p_bvi_sq)
49-
/ (one_p_bvi_sq * TWO_PI);
50-
}
51-
};
52-
} // namespace internal
53-
5416
/**
5517
* The Owen's T function of h and a.
5618
*
5719
* Used to compute the cumulative density function for the skew normal
5820
* distribution.
5921
*
22+
* @tparam Var1 A scalar or Eigen type whose `scalar_type` is an var.
23+
* @tparam Var2 A scalar or Eigen type whose `scalar_type` is an var.
6024
* @param h var parameter.
6125
* @param a var parameter.
6226
* @return The Owen's T function.
6327
*/
64-
inline var owens_t(const var& h, const var& a) {
65-
return var(new internal::owens_t_vv_vari(h.vi_, a.vi_));
28+
template <typename Var1, typename Var2,
29+
require_all_st_var<Var1, Var2>* = nullptr,
30+
require_all_not_std_vector_t<Var1, Var2>* = nullptr>
31+
inline auto owens_t(const Var1& h, const Var2& a) {
32+
auto h_arena = to_arena(h);
33+
auto a_arena = to_arena(a);
34+
using return_type
35+
= return_var_matrix_t<decltype(owens_t(h_arena.val(), a_arena.val())),
36+
Var1, Var2>;
37+
arena_t<return_type> ret = owens_t(h_arena.val(), a_arena.val());
38+
reverse_pass_callback([h_arena, a_arena, ret]() mutable {
39+
const auto& h_val = as_value_array_or_scalar(h_arena);
40+
const auto& a_val = as_value_array_or_scalar(a_arena);
41+
const auto neg_h_sq_div_2 = stan::math::eval(-square(h_val) * 0.5);
42+
const auto one_p_a_sq = stan::math::eval(1.0 + square(a_val));
43+
as_array_or_scalar(h_arena).adj() += possibly_sum<is_stan_scalar<Var1>>(
44+
as_array_or_scalar(ret.adj()) * erf(a_val * h_val * INV_SQRT_TWO)
45+
* exp(neg_h_sq_div_2) * INV_SQRT_TWO_PI * -0.5);
46+
as_array_or_scalar(a_arena).adj() += possibly_sum<is_stan_scalar<Var2>>(
47+
as_array_or_scalar(ret.adj()) * exp(neg_h_sq_div_2 * one_p_a_sq)
48+
/ (one_p_a_sq * TWO_PI));
49+
});
50+
return return_type(ret);
6651
}
6752

6853
/**
@@ -71,12 +56,30 @@ inline var owens_t(const var& h, const var& a) {
7156
* Used to compute the cumulative density function for the skew normal
7257
* distribution.
7358
*
59+
* @tparam Var A scalar or Eigen type whose `scalar_type` is an var.
60+
* @tparam Arith A scalar or Eigen type with an inner arirthmetic scalar value.
7461
* @param h var parameter.
7562
* @param a double parameter.
7663
* @return The Owen's T function.
7764
*/
78-
inline var owens_t(const var& h, double a) {
79-
return var(new internal::owens_t_vd_vari(h.vi_, a));
65+
template <typename Var, typename Arith, require_st_arithmetic<Arith>* = nullptr,
66+
require_all_not_std_vector_t<Var, Arith>* = nullptr,
67+
require_st_var<Var>* = nullptr>
68+
inline auto owens_t(const Var& h, const Arith& a) {
69+
auto h_arena = to_arena(h);
70+
auto a_arena = to_arena(a);
71+
using return_type
72+
= return_var_matrix_t<decltype(owens_t(h_arena.val(), a_arena)), Var,
73+
Arith>;
74+
arena_t<return_type> ret = owens_t(h_arena.val(), a_arena);
75+
reverse_pass_callback([h_arena, a_arena, ret]() mutable {
76+
const auto& h_val = as_value_array_or_scalar(h_arena);
77+
as_array_or_scalar(h_arena).adj() += possibly_sum<is_stan_scalar<Var>>(
78+
as_array_or_scalar(ret.adj())
79+
* erf(as_array_or_scalar(a_arena) * h_val * INV_SQRT_TWO)
80+
* exp(-square(h_val) * 0.5) * INV_SQRT_TWO_PI * -0.5);
81+
});
82+
return return_type(ret);
8083
}
8184

8285
/**
@@ -85,12 +88,31 @@ inline var owens_t(const var& h, double a) {
8588
* Used to compute the cumulative density function for the skew normal
8689
* distribution.
8790
*
91+
* @tparam Var A scalar or Eigen type whose `scalar_type` is an var.
92+
* @tparam Arith A scalar or Eigen type with an inner arithmetic scalar value.
8893
* @param h double parameter.
8994
* @param a var parameter.
9095
* @return The Owen's T function.
9196
*/
92-
inline var owens_t(double h, const var& a) {
93-
return var(new internal::owens_t_dv_vari(h, a.vi_));
97+
template <typename Arith, typename Var, require_st_arithmetic<Arith>* = nullptr,
98+
require_all_not_std_vector_t<Var, Arith>* = nullptr,
99+
require_st_var<Var>* = nullptr>
100+
inline auto owens_t(const Arith& h, const Var& a) {
101+
auto h_arena = to_arena(h);
102+
auto a_arena = to_arena(a);
103+
using return_type
104+
= return_var_matrix_t<decltype(owens_t(h_arena, a_arena.val())), Var,
105+
Arith>;
106+
arena_t<return_type> ret = owens_t(h_arena, a_arena.val());
107+
reverse_pass_callback([h_arena, a_arena, ret]() mutable {
108+
const auto one_p_a_sq
109+
= eval(1.0 + square(as_value_array_or_scalar(a_arena)));
110+
as_array_or_scalar(a_arena).adj() += possibly_sum<is_stan_scalar<Var>>(
111+
as_array_or_scalar(ret.adj())
112+
* exp(-0.5 * square(as_array_or_scalar(h_arena)) * one_p_a_sq)
113+
/ (one_p_a_sq * TWO_PI));
114+
});
115+
return return_type(ret);
94116
}
95117

96118
} // namespace math

test/unit/math/mix/fun/owens_t_test.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,21 @@ TEST(mathMixScalFun, owensT) {
1515
stan::test::expect_ad(f, nan, nan);
1616
}
1717

18+
TEST(mathMixScalFun, owensT_varmat) {
19+
auto f = [](const auto& x1, const auto& x2) {
20+
return stan::math::owens_t(x1, x2);
21+
};
22+
double scal = 2.0;
23+
Eigen::MatrixXd mat = Eigen::MatrixXd::Random(2, 2);
24+
Eigen::VectorXd vec = Eigen::VectorXd::Random(2);
25+
stan::test::expect_ad_matvar(f, mat, mat);
26+
stan::test::expect_ad_matvar(f, mat, scal);
27+
stan::test::expect_ad_matvar(f, scal, mat);
28+
stan::test::expect_ad_matvar(f, vec, vec);
29+
stan::test::expect_ad_matvar(f, vec, scal);
30+
stan::test::expect_ad_matvar(f, scal, vec);
31+
}
32+
1833
TEST(mathMixScalFun, owensT_vec) {
1934
auto f = [](const auto& x1, const auto& x2) {
2035
using stan::math::owens_t;
@@ -27,3 +42,16 @@ TEST(mathMixScalFun, owensT_vec) {
2742
in2 << 3.0, 4.0;
2843
stan::test::expect_ad_vectorized_binary(f, in1, in2);
2944
}
45+
46+
TEST(mathMixScalFun, owensT_vec_matvar) {
47+
auto f = [](const auto& x1, const auto& x2) {
48+
using stan::math::owens_t;
49+
return owens_t(x1, x2);
50+
};
51+
52+
Eigen::MatrixXd in1(2, 2);
53+
in1 << 0.5, 3.4, 5.2, 0.5;
54+
Eigen::MatrixXd in2(2, 2);
55+
in2 << 3.3, 0.9, 6.7, 3.3;
56+
stan::test::expect_ad_vectorized_matvar(f, in1, in2);
57+
}

0 commit comments

Comments
 (0)