|
| 1 | +#ifndef STAN_MATH_FWD_FUNCTOR_INTEGRATE_1D_HPP |
| 2 | +#define STAN_MATH_FWD_FUNCTOR_INTEGRATE_1D_HPP |
| 3 | + |
| 4 | +#include <stan/math/fwd/meta.hpp> |
| 5 | +#include <stan/math/prim/functor/integrate_1d.hpp> |
| 6 | +#include <stan/math/prim/fun/value_of.hpp> |
| 7 | +#include <stan/math/prim/meta/forward_as.hpp> |
| 8 | +#include <stan/math/prim/functor/apply.hpp> |
| 9 | +#include <stan/math/fwd/functor/finite_diff.hpp> |
| 10 | + |
| 11 | +namespace stan { |
| 12 | +namespace math { |
| 13 | +/** |
| 14 | + * Return the integral of f from a to b to the given relative tolerance |
| 15 | + * |
| 16 | + * @tparam F Type of f |
| 17 | + * @tparam T_a type of first limit |
| 18 | + * @tparam T_b type of second limit |
| 19 | + * @tparam Args types of parameter pack arguments |
| 20 | + * |
| 21 | + * @param f the functor to integrate |
| 22 | + * @param a lower limit of integration |
| 23 | + * @param b upper limit of integration |
| 24 | + * @param relative_tolerance relative tolerance passed to Boost quadrature |
| 25 | + * @param[in, out] msgs the print stream for warning messages |
| 26 | + * @param args additional arguments to pass to f |
| 27 | + * @return numeric integral of function f |
| 28 | + */ |
| 29 | +template <typename F, typename T_a, typename T_b, typename... Args, |
| 30 | + require_any_st_fvar<T_a, T_b, Args...> * = nullptr> |
| 31 | +inline return_type_t<T_a, T_b, Args...> integrate_1d_impl( |
| 32 | + const F &f, const T_a &a, const T_b &b, double relative_tolerance, |
| 33 | + std::ostream *msgs, const Args &... args) { |
| 34 | + using FvarT = scalar_type_t<return_type_t<T_a, T_b, Args...>>; |
| 35 | + |
| 36 | + // Wrap integrate_1d call in a functor where the input arguments are only |
| 37 | + // for which tangents are needed |
| 38 | + auto a_val = value_of(a); |
| 39 | + auto b_val = value_of(b); |
| 40 | + auto func |
| 41 | + = [f, msgs, relative_tolerance, a_val, b_val](const auto &... args_var) { |
| 42 | + return integrate_1d_impl(f, a_val, b_val, relative_tolerance, msgs, |
| 43 | + args_var...); |
| 44 | + }; |
| 45 | + FvarT ret = finite_diff(func, args...); |
| 46 | + |
| 47 | + // Calculate tangents w.r.t. integration bounds if needed |
| 48 | + if (is_fvar<T_a>::value || is_fvar<T_b>::value) { |
| 49 | + auto val_args = std::make_tuple(value_of(args)...); |
| 50 | + if (is_fvar<T_a>::value) { |
| 51 | + ret.d_ += math::forward_as<FvarT>(a).d_ |
| 52 | + * math::apply( |
| 53 | + [&](auto &&... tuple_args) { |
| 54 | + return -f(a_val, 0.0, msgs, tuple_args...); |
| 55 | + }, |
| 56 | + val_args); |
| 57 | + } |
| 58 | + if (is_fvar<T_b>::value) { |
| 59 | + ret.d_ += math::forward_as<FvarT>(b).d_ |
| 60 | + * math::apply( |
| 61 | + [&](auto &&... tuple_args) { |
| 62 | + return f(b_val, 0.0, msgs, tuple_args...); |
| 63 | + }, |
| 64 | + val_args); |
| 65 | + } |
| 66 | + } |
| 67 | + return ret; |
| 68 | +} |
| 69 | + |
| 70 | +/** |
| 71 | + * Compute the integral of the single variable function f from a to b to within |
| 72 | + * a specified relative tolerance. a and b can be finite or infinite. |
| 73 | + * |
| 74 | + * @tparam T_a type of first limit |
| 75 | + * @tparam T_b type of second limit |
| 76 | + * @tparam T_theta type of parameters |
| 77 | + * @tparam T Type of f |
| 78 | + * |
| 79 | + * @param f the functor to integrate |
| 80 | + * @param a lower limit of integration |
| 81 | + * @param b upper limit of integration |
| 82 | + * @param theta additional parameters to be passed to f |
| 83 | + * @param x_r additional data to be passed to f |
| 84 | + * @param x_i additional integer data to be passed to f |
| 85 | + * @param[in, out] msgs the print stream for warning messages |
| 86 | + * @param relative_tolerance relative tolerance passed to Boost quadrature |
| 87 | + * @return numeric integral of function f |
| 88 | + */ |
| 89 | +template <typename F, typename T_a, typename T_b, typename T_theta, |
| 90 | + require_any_fvar_t<T_a, T_b, T_theta> * = nullptr> |
| 91 | +inline return_type_t<T_a, T_b, T_theta> integrate_1d( |
| 92 | + const F &f, const T_a &a, const T_b &b, const std::vector<T_theta> &theta, |
| 93 | + const std::vector<double> &x_r, const std::vector<int> &x_i, |
| 94 | + std::ostream *msgs, const double relative_tolerance) { |
| 95 | + return integrate_1d_impl(integrate_1d_adapter<F>(f), a, b, relative_tolerance, |
| 96 | + msgs, theta, x_r, x_i); |
| 97 | +} |
| 98 | + |
| 99 | +} // namespace math |
| 100 | +} // namespace stan |
| 101 | +#endif |
0 commit comments