Skip to content

Commit f8d886d

Browse files
committed
Update doc & naming
1 parent 8037cd5 commit f8d886d

3 files changed

Lines changed: 25 additions & 22 deletions

File tree

stan/math/fwd/functor.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
#include <stan/math/fwd/functor/apply_scalar_unary.hpp>
55
#include <stan/math/fwd/functor/gradient.hpp>
6-
#include <stan/math/fwd/functor/fvar_finite_diff.hpp>
6+
#include <stan/math/fwd/functor/finite_diff.hpp>
77
#include <stan/math/fwd/functor/hessian.hpp>
88
#include <stan/math/fwd/functor/integrate_1d.hpp>
99
#include <stan/math/fwd/functor/jacobian.hpp>

stan/math/fwd/functor/fvar_finite_diff.hpp renamed to stan/math/fwd/functor/finite_diff.hpp

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#ifndef STAN_MATH_FWD_FUNCTOR_FVAR_FINITE_DIFF_HPP
2-
#define STAN_MATH_FWD_FUNCTOR_FVAR_FINITE_DIFF_HPP
1+
#ifndef STAN_MATH_FWD_FUNCTOR_FINITE_DIFF_HPP
2+
#define STAN_MATH_FWD_FUNCTOR_FINITE_DIFF_HPP
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
@@ -43,35 +43,36 @@ inline constexpr double aggregate_tangent(const FuncTangent& tangent,
4343
*/
4444
template <typename FuncTangent, typename InputArg,
4545
require_st_fvar<InputArg>* = nullptr>
46-
auto aggregate_tangent(const FuncTangent& tangent, const InputArg& arg) {
46+
inline auto aggregate_tangent(const FuncTangent& tangent, const InputArg& arg) {
4747
return sum(apply_scalar_binary(
4848
tangent, arg, [](const auto& x, const auto& y) { return x * y.d_; }));
4949
}
5050
} // namespace internal
5151

5252
/**
53-
* This frameworks adds fvar<T> support for arbitrary functions through
54-
* finite-differencing. Higher-order inputs (i.e., fvar<var> & fvar<fvar<T>>)
55-
* are also implicitly supported.
53+
* Construct an fvar<T> where the tangent is calculated by finite-differencing.
54+
* Finite-differencing is only perfomed where the scalar type to be evaluated is
55+
* `fvar<T>.
56+
*
57+
* Higher-order inputs (i.e., fvar<var> & fvar<fvar<T>>) are also implicitly
58+
* supported through auto-diffing the finite-differencing process.
5659
*
5760
* @tparam F Type of functor for which fvar<T> support is needed
58-
* @tparam TArgs... Types of arguments (containing at least one fvar<T> type)
59-
* to be passed to function
61+
* @tparam TArgs Template parameter pack of the types passed in the `operator()`
62+
* of the functor type `F`. Must contain at least on type whose
63+
* scalar type is `fvar<T>`
6064
* @param func Functor for which fvar<T> support is needed
6165
* @param args Parameter pack of arguments to be passed to functor.
6266
*/
6367
template <typename F, typename... TArgs,
6468
require_any_st_fvar<TArgs...>* = nullptr>
65-
auto fvar_finite_diff(const F& func, const TArgs&... args) {
69+
inline auto finite_diff(const F& func, const TArgs&... args) {
6670
using FvarT = return_type_t<TArgs...>;
6771
using FvarInnerT = typename FvarT::Scalar;
6872

6973
std::vector<FvarInnerT> serialised_args
7074
= serialize<FvarInnerT>(value_of(args)...);
7175

72-
// Create a 'wrapper' functor which will take the flattened column-vector
73-
// and transform it to individual arguments which are passed to the
74-
// user-provided functor
7576
auto serial_functor = [&](const auto& v) {
7677
auto v_deserializer = to_deserializer(v);
7778
return func(v_deserializer.read(args)...);
@@ -93,21 +94,23 @@ auto fvar_finite_diff(const F& func, const TArgs&... args) {
9394
}
9495

9596
/**
96-
* This frameworks adds fvar<T> support for arbitrary functions through
97-
* finite-differencing. Higher-order inputs (i.e., fvar<var> & fvar<fvar<T>>)
98-
* are also implicitly supported.
97+
* Construct an fvar<T> where the tangent is calculated by finite-differencing.
98+
* Finite-differencing is only perfomed where the scalar type to be evaluated is
99+
* `fvar<T>.
99100
*
100-
* Overload for use when no fvar<T> arguments are passed, and finite-differences
101-
* are not needed.
101+
* This overload is used when no fvar<T> arguments are passed and simply
102+
* evaluates the functor with the provided arguments.
102103
*
103104
* @tparam F Type of functor
104-
* @tparam TArgs... Types of arguments (containing no fvar<T> types)
105+
* @tparam TArgs Template parameter pack of the types passed in the `operator()`
106+
* of the functor type `F`. Must contain no type whose
107+
* scalar type is `fvar<T>`
105108
* @param func Functor
106109
* @param args... Parameter pack of arguments to be passed to functor.
107110
*/
108111
template <typename F, typename... TArgs,
109112
require_all_not_st_fvar<TArgs...>* = nullptr>
110-
auto fvar_finite_diff(const F& func, const TArgs&... args) {
113+
inline auto finite_diff(const F& func, const TArgs&... args) {
111114
return func(args...);
112115
}
113116

stan/math/fwd/functor/integrate_1d.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include <stan/math/prim/fun/value_of.hpp>
77
#include <stan/math/prim/meta/forward_as.hpp>
88
#include <stan/math/prim/functor/apply.hpp>
9-
#include <stan/math/fwd/functor/fvar_finite_diff.hpp>
9+
#include <stan/math/fwd/functor/finite_diff.hpp>
1010

1111
namespace stan {
1212
namespace math {
@@ -42,7 +42,7 @@ inline return_type_t<T_a, T_b, Args...> integrate_1d_impl(
4242
return integrate_1d_impl(f, a_val, b_val, relative_tolerance, msgs,
4343
args_var...);
4444
};
45-
FvarT ret = fvar_finite_diff(func, args...);
45+
FvarT ret = finite_diff(func, args...);
4646

4747
// Calculate tangents w.r.t. integration bounds if needed
4848
if (is_fvar<T_a>::value || is_fvar<T_b>::value) {

0 commit comments

Comments
 (0)