Skip to content

Commit efbc688

Browse files
authored
Merge pull request #2929 from stan-dev/fvar-support
Framework for generic fvar<T> support through finite-differences
2 parents 11b3aff + ad303aa commit efbc688

11 files changed

Lines changed: 291 additions & 38 deletions

File tree

stan/math/fwd/functor.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
#include <stan/math/fwd/functor/apply_scalar_unary.hpp>
55
#include <stan/math/fwd/functor/gradient.hpp>
6+
#include <stan/math/fwd/functor/finite_diff.hpp>
67
#include <stan/math/fwd/functor/hessian.hpp>
8+
#include <stan/math/fwd/functor/integrate_1d.hpp>
79
#include <stan/math/fwd/functor/jacobian.hpp>
810
#include <stan/math/fwd/functor/operands_and_partials.hpp>
911
#include <stan/math/fwd/functor/partials_propagator.hpp>
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#ifndef STAN_MATH_FWD_FUNCTOR_FINITE_DIFF_HPP
2+
#define STAN_MATH_FWD_FUNCTOR_FINITE_DIFF_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
6+
#include <stan/math/prim/functor/finite_diff_gradient_auto.hpp>
7+
#include <stan/math/prim/fun/value_of.hpp>
8+
#include <stan/math/prim/fun/sum.hpp>
9+
#include <stan/math/prim/fun/serializer.hpp>
10+
11+
namespace stan {
12+
namespace math {
13+
namespace internal {
14+
/**
15+
* Helper function for aggregating tangents if the respective input argument
16+
* was an fvar<T> type.
17+
*
18+
* Overload for when the input is not an fvar<T> and no tangents are needed.
19+
*
20+
* @tparam FuncTangent Type of tangent calculated by finite-differences
21+
* @tparam InputArg Type of the function input argument
22+
* @param tangent Calculated tangent
23+
* @param arg Input argument
24+
*/
25+
template <typename FuncTangent, typename InputArg,
26+
require_not_st_fvar<InputArg>* = nullptr>
27+
inline constexpr double aggregate_tangent(const FuncTangent& tangent,
28+
const InputArg& arg) {
29+
return 0;
30+
}
31+
32+
/**
33+
* Helper function for aggregating tangents if the respective input argument
34+
* was an fvar<T> type.
35+
*
36+
* Overload for when the input is an fvar<T> and its tangent needs to be
37+
* aggregated.
38+
*
39+
* @tparam FuncTangent Type of tangent calculated by finite-differences
40+
* @tparam InputArg Type of the function input argument
41+
* @param tangent Calculated tangent
42+
* @param arg Input argument
43+
*/
44+
template <typename FuncTangent, typename InputArg,
45+
require_st_fvar<InputArg>* = nullptr>
46+
inline auto aggregate_tangent(const FuncTangent& tangent, const InputArg& arg) {
47+
return sum(apply_scalar_binary(
48+
tangent, arg, [](const auto& x, const auto& y) { return x * y.d_; }));
49+
}
50+
} // namespace internal
51+
52+
/**
53+
* Construct an fvar<T> where the tangent is calculated by finite-differencing.
54+
* Finite-differencing is only perfomed where the scalar type to be evaluated is
55+
* `fvar<T>.
56+
*
57+
* Higher-order inputs (i.e., fvar<var> & fvar<fvar<T>>) are also implicitly
58+
* supported through auto-diffing the finite-differencing process.
59+
*
60+
* @tparam F Type of functor for which fvar<T> support is needed
61+
* @tparam TArgs Template parameter pack of the types passed in the `operator()`
62+
* of the functor type `F`. Must contain at least on type whose
63+
* scalar type is `fvar<T>`
64+
* @param func Functor for which fvar<T> support is needed
65+
* @param args Parameter pack of arguments to be passed to functor.
66+
*/
67+
template <typename F, typename... TArgs,
68+
require_any_st_fvar<TArgs...>* = nullptr>
69+
inline auto finite_diff(const F& func, const TArgs&... args) {
70+
using FvarT = return_type_t<TArgs...>;
71+
using FvarInnerT = typename FvarT::Scalar;
72+
73+
std::vector<FvarInnerT> serialised_args
74+
= serialize<FvarInnerT>(value_of(args)...);
75+
76+
auto serial_functor = [&](const auto& v) {
77+
auto v_deserializer = to_deserializer(v);
78+
return func(v_deserializer.read(args)...);
79+
};
80+
81+
FvarInnerT rtn_value;
82+
std::vector<FvarInnerT> grad;
83+
finite_diff_gradient_auto(serial_functor, serialised_args, rtn_value, grad);
84+
85+
FvarInnerT rtn_grad = 0;
86+
auto grad_deserializer = to_deserializer(grad);
87+
// Use a fold-expression to aggregate tangents for input arguments
88+
static_cast<void>(
89+
std::initializer_list<int>{(rtn_grad += internal::aggregate_tangent(
90+
grad_deserializer.read(args), args),
91+
0)...});
92+
93+
return FvarT(rtn_value, rtn_grad);
94+
}
95+
96+
/**
97+
* Construct an fvar<T> where the tangent is calculated by finite-differencing.
98+
* Finite-differencing is only perfomed where the scalar type to be evaluated is
99+
* `fvar<T>.
100+
*
101+
* This overload is used when no fvar<T> arguments are passed and simply
102+
* evaluates the functor with the provided arguments.
103+
*
104+
* @tparam F Type of functor
105+
* @tparam TArgs Template parameter pack of the types passed in the `operator()`
106+
* of the functor type `F`. Must contain no type whose
107+
* scalar type is `fvar<T>`
108+
* @param func Functor
109+
* @param args... Parameter pack of arguments to be passed to functor.
110+
*/
111+
template <typename F, typename... TArgs,
112+
require_all_not_st_fvar<TArgs...>* = nullptr>
113+
inline auto finite_diff(const F& func, const TArgs&... args) {
114+
return func(args...);
115+
}
116+
117+
} // namespace math
118+
} // namespace stan
119+
120+
#endif
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
#ifndef STAN_MATH_FWD_FUNCTOR_INTEGRATE_1D_HPP
2+
#define STAN_MATH_FWD_FUNCTOR_INTEGRATE_1D_HPP
3+
4+
#include <stan/math/fwd/meta.hpp>
5+
#include <stan/math/prim/functor/integrate_1d.hpp>
6+
#include <stan/math/prim/fun/value_of.hpp>
7+
#include <stan/math/prim/meta/forward_as.hpp>
8+
#include <stan/math/prim/functor/apply.hpp>
9+
#include <stan/math/fwd/functor/finite_diff.hpp>
10+
11+
namespace stan {
12+
namespace math {
13+
/**
14+
* Return the integral of f from a to b to the given relative tolerance
15+
*
16+
* @tparam F Type of f
17+
* @tparam T_a type of first limit
18+
* @tparam T_b type of second limit
19+
* @tparam Args types of parameter pack arguments
20+
*
21+
* @param f the functor to integrate
22+
* @param a lower limit of integration
23+
* @param b upper limit of integration
24+
* @param relative_tolerance relative tolerance passed to Boost quadrature
25+
* @param[in, out] msgs the print stream for warning messages
26+
* @param args additional arguments to pass to f
27+
* @return numeric integral of function f
28+
*/
29+
template <typename F, typename T_a, typename T_b, typename... Args,
30+
require_any_st_fvar<T_a, T_b, Args...> * = nullptr>
31+
inline return_type_t<T_a, T_b, Args...> integrate_1d_impl(
32+
const F &f, const T_a &a, const T_b &b, double relative_tolerance,
33+
std::ostream *msgs, const Args &... args) {
34+
using FvarT = scalar_type_t<return_type_t<T_a, T_b, Args...>>;
35+
36+
// Wrap integrate_1d call in a functor where the input arguments are only
37+
// for which tangents are needed
38+
auto a_val = value_of(a);
39+
auto b_val = value_of(b);
40+
auto func
41+
= [f, msgs, relative_tolerance, a_val, b_val](const auto &... args_var) {
42+
return integrate_1d_impl(f, a_val, b_val, relative_tolerance, msgs,
43+
args_var...);
44+
};
45+
FvarT ret = finite_diff(func, args...);
46+
47+
// Calculate tangents w.r.t. integration bounds if needed
48+
if (is_fvar<T_a>::value || is_fvar<T_b>::value) {
49+
auto val_args = std::make_tuple(value_of(args)...);
50+
if (is_fvar<T_a>::value) {
51+
ret.d_ += math::forward_as<FvarT>(a).d_
52+
* math::apply(
53+
[&](auto &&... tuple_args) {
54+
return -f(a_val, 0.0, msgs, tuple_args...);
55+
},
56+
val_args);
57+
}
58+
if (is_fvar<T_b>::value) {
59+
ret.d_ += math::forward_as<FvarT>(b).d_
60+
* math::apply(
61+
[&](auto &&... tuple_args) {
62+
return f(b_val, 0.0, msgs, tuple_args...);
63+
},
64+
val_args);
65+
}
66+
}
67+
return ret;
68+
}
69+
70+
/**
71+
* Compute the integral of the single variable function f from a to b to within
72+
* a specified relative tolerance. a and b can be finite or infinite.
73+
*
74+
* @tparam T_a type of first limit
75+
* @tparam T_b type of second limit
76+
* @tparam T_theta type of parameters
77+
* @tparam T Type of f
78+
*
79+
* @param f the functor to integrate
80+
* @param a lower limit of integration
81+
* @param b upper limit of integration
82+
* @param theta additional parameters to be passed to f
83+
* @param x_r additional data to be passed to f
84+
* @param x_i additional integer data to be passed to f
85+
* @param[in, out] msgs the print stream for warning messages
86+
* @param relative_tolerance relative tolerance passed to Boost quadrature
87+
* @return numeric integral of function f
88+
*/
89+
template <typename F, typename T_a, typename T_b, typename T_theta,
90+
require_any_fvar_t<T_a, T_b, T_theta> * = nullptr>
91+
inline return_type_t<T_a, T_b, T_theta> integrate_1d(
92+
const F &f, const T_a &a, const T_b &b, const std::vector<T_theta> &theta,
93+
const std::vector<double> &x_r, const std::vector<int> &x_i,
94+
std::ostream *msgs, const double relative_tolerance) {
95+
return integrate_1d_impl(integrate_1d_adapter<F>(f), a, b, relative_tolerance,
96+
msgs, theta, x_r, x_i);
97+
}
98+
99+
} // namespace math
100+
} // namespace stan
101+
#endif

stan/math/prim/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@
311311
#include <stan/math/prim/fun/scaled_add.hpp>
312312
#include <stan/math/prim/fun/sd.hpp>
313313
#include <stan/math/prim/fun/segment.hpp>
314+
#include <stan/math/prim/fun/serializer.hpp>
314315
#include <stan/math/prim/fun/select.hpp>
315316
#include <stan/math/prim/fun/sign.hpp>
316317
#include <stan/math/prim/fun/signbit.hpp>
Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1-
#ifndef TEST_UNIT_MATH_SERIALIZER_HPP
2-
#define TEST_UNIT_MATH_SERIALIZER_HPP
1+
#ifndef STAN_MATH_PRIM_FUN_SERIALIZER_HPP
2+
#define STAN_MATH_PRIM_FUN_SERIALIZER_HPP
33

4-
#include <stan/math.hpp>
4+
#include <stan/math/prim/meta/promote_scalar_type.hpp>
5+
#include <stan/math/prim/fun/to_vector.hpp>
6+
#include <stan/math/prim/fun/to_array_1d.hpp>
57
#include <complex>
68
#include <string>
79
#include <vector>
810

911
namespace stan {
10-
namespace test {
12+
namespace math {
1113

1214
/**
1315
* A class to store a sequence of values which can be deserialized
@@ -44,10 +46,10 @@ struct deserializer {
4446
/**
4547
* Construct a deserializer from the specified sequence of values.
4648
*
47-
* @param vals values to deserialize
49+
* @param v_vals values to deserialize
4850
*/
4951
explicit deserializer(const Eigen::Matrix<T, -1, 1>& v_vals)
50-
: position_(0), vals_(math::to_array_1d(v_vals)) {}
52+
: position_(0), vals_(to_array_1d(v_vals)) {}
5153

5254
/**
5355
* Read a scalar conforming to the shape of the specified argument,
@@ -94,8 +96,8 @@ struct deserializer {
9496
*/
9597
template <typename U, require_std_vector_t<U>* = nullptr,
9698
require_not_st_complex<U>* = nullptr>
97-
typename stan::math::promote_scalar_type<T, U>::type read(const U& x) {
98-
typename stan::math::promote_scalar_type<T, U>::type y;
99+
promote_scalar_t<T, U> read(const U& x) {
100+
promote_scalar_t<T, U> y;
99101
y.reserve(x.size());
100102
for (size_t i = 0; i < x.size(); ++i)
101103
y.push_back(read(x[i]));
@@ -113,9 +115,8 @@ struct deserializer {
113115
* @return deserialized value with shape and size matching argument
114116
*/
115117
template <typename U, require_std_vector_st<is_complex, U>* = nullptr>
116-
typename stan::math::promote_scalar_type<std::complex<T>, U>::type read(
117-
const U& x) {
118-
typename stan::math::promote_scalar_type<std::complex<T>, U>::type y;
118+
promote_scalar_t<std::complex<T>, U> read(const U& x) {
119+
promote_scalar_t<std::complex<T>, U> y;
119120
y.reserve(x.size());
120121
for (size_t i = 0; i < x.size(); ++i)
121122
y.push_back(read(x[i]));
@@ -257,9 +258,7 @@ struct serializer {
257258
*
258259
* @return serialized values
259260
*/
260-
const Eigen::Matrix<T, -1, 1>& vector_vals() {
261-
return math::to_vector(vals_);
262-
}
261+
const Eigen::Matrix<T, -1, 1>& vector_vals() { return to_vector(vals_); }
263262
};
264263

265264
/**
@@ -338,10 +337,10 @@ std::vector<real_return_t<T>> serialize_return(const T& x) {
338337
*/
339338
template <typename... Ts>
340339
Eigen::VectorXd serialize_args(const Ts... xs) {
341-
return math::to_vector(serialize<double>(xs...));
340+
return to_vector(serialize<double>(xs...));
342341
}
343342

344-
} // namespace test
343+
} // namespace math
345344
} // namespace stan
346345

347346
#endif

stan/math/prim/functor/finite_diff_gradient_auto.hpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,39 +46,40 @@ namespace math {
4646
* @param[out] fx function applied to argument
4747
* @param[out] grad_fx gradient of function at argument
4848
*/
49-
template <typename F>
50-
void finite_diff_gradient_auto(const F& f, const Eigen::VectorXd& x, double& fx,
51-
Eigen::VectorXd& grad_fx) {
52-
Eigen::VectorXd x_temp(x);
49+
template <typename F, typename VectorT,
50+
typename ScalarT = return_type_t<VectorT>>
51+
void finite_diff_gradient_auto(const F& f, const VectorT& x, ScalarT& fx,
52+
VectorT& grad_fx) {
53+
VectorT x_temp(x);
5354
fx = f(x);
5455
grad_fx.resize(x.size());
5556
for (int i = 0; i < x.size(); ++i) {
56-
double h = finite_diff_stepsize(x(i));
57+
double h = finite_diff_stepsize(value_of_rec(x[i]));
5758

58-
double delta_f = 0;
59+
ScalarT delta_f = 0;
5960

60-
x_temp(i) = x(i) + 3 * h;
61+
x_temp[i] = x[i] + 3 * h;
6162
delta_f += f(x_temp);
6263

63-
x_temp(i) = x(i) + 2 * h;
64+
x_temp[i] = x[i] + 2 * h;
6465
delta_f -= 9 * f(x_temp);
6566

66-
x_temp(i) = x(i) + h;
67+
x_temp[i] = x[i] + h;
6768
delta_f += 45 * f(x_temp);
6869

69-
x_temp(i) = x(i) + -3 * h;
70+
x_temp[i] = x[i] + -3 * h;
7071
delta_f -= f(x_temp);
7172

72-
x_temp(i) = x(i) + -2 * h;
73+
x_temp[i] = x[i] + -2 * h;
7374
delta_f += 9 * f(x_temp);
7475

75-
x_temp(i) = x(i) - h;
76+
x_temp[i] = x[i] - h;
7677
delta_f -= 45 * f(x_temp);
7778

7879
delta_f /= 60 * h;
7980

80-
x_temp(i) = x(i);
81-
grad_fx(i) = delta_f;
81+
x_temp[i] = x[i];
82+
grad_fx[i] = delta_f;
8283
}
8384
}
8485

stan/math/prim/functor/integrate_1d.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ inline double integrate(const F& f, double a, double b,
171171
* @return numeric integral of function f
172172
*/
173173
template <typename F, typename... Args,
174-
require_all_not_st_var<Args...>* = nullptr>
174+
require_all_st_arithmetic<Args...>* = nullptr>
175175
inline double integrate_1d_impl(const F& f, double a, double b,
176176
double relative_tolerance, std::ostream* msgs,
177177
const Args&... args) {

0 commit comments

Comments
 (0)