Skip to content

Commit 1536d12

Browse files
committed
Optimise deserializer use, generalise finite_diff input to remove copy
1 parent 4209cc9 commit 1536d12

2 files changed

Lines changed: 35 additions & 30 deletions

File tree

stan/math/fwd/functor/fvar_finite_diff.hpp

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@ namespace internal {
1717
*
1818
* Overload for when the input is not an fvar<T> and no tangents are needed.
1919
*
20-
* @tparam FuncTangentT Type of tangent calculated by finite-differences
21-
* @tparam InputArgT Type of the function input argument
20+
* @tparam FuncTangent Type of tangent calculated by finite-differences
21+
* @tparam InputArg Type of the function input argument
2222
* @param tangent Calculated tangent
2323
* @param arg Input argument
2424
*/
25-
template <typename FuncTangentT, typename InputArgT,
26-
require_not_st_fvar<InputArgT>* = nullptr>
27-
inline constexpr double aggregate_tangent(const FuncTangentT& tangent,
28-
const InputArgT& arg) {
25+
template <typename FuncTangent, typename InputArg,
26+
require_not_st_fvar<InputArg>* = nullptr>
27+
inline constexpr double aggregate_tangent(const FuncTangent& tangent,
28+
const InputArg& arg) {
2929
return 0;
3030
}
3131

@@ -36,14 +36,14 @@ inline constexpr double aggregate_tangent(const FuncTangentT& tangent,
3636
* Overload for when the input is an fvar<T> and its tangent needs to be
3737
* aggregated.
3838
*
39-
* @tparam FuncTangentT Type of tangent calculated by finite-differences
40-
* @tparam InputArgT Type of the function input argument
39+
* @tparam FuncTangent Type of tangent calculated by finite-differences
40+
* @tparam InputArg Type of the function input argument
4141
* @param tangent Calculated tangent
4242
* @param arg Input argument
4343
*/
44-
template <typename FuncTangentT, typename InputArgT,
45-
require_st_fvar<InputArgT>* = nullptr>
46-
auto aggregate_tangent(const FuncTangentT& tangent, const InputArgT& arg) {
44+
template <typename FuncTangent, typename InputArg,
45+
require_st_fvar<InputArg>* = nullptr>
46+
auto aggregate_tangent(const FuncTangent& tangent, const InputArg& arg) {
4747
return sum(apply_scalar_binary(
4848
tangent, arg, [](const auto& x, const auto& y) { return x * y.d_; }));
4949
}
@@ -66,23 +66,27 @@ auto fvar_finite_diff(const F& func, const TArgs&... args) {
6666
using FvarT = return_type_t<TArgs...>;
6767
using FvarInnerT = typename FvarT::Scalar;
6868

69-
auto serialised_args = serialize<FvarInnerT>(value_of(args)...);
69+
std::vector<FvarInnerT> serialised_args
70+
= serialize<FvarInnerT>(value_of(args)...);
7071

7172
// Create a 'wrapper' functor which will take the flattened column-vector
7273
// and transform it to individual arguments which are passed to the
7374
// user-provided functor
7475
auto serial_functor
75-
= [&](const auto& v) { return func(to_deserializer(v).read(args)...); };
76+
= [&](const auto& v) {
77+
auto v_deserializer = to_deserializer(v);
78+
return func(v_deserializer.read(args)...);
79+
};
7680

7781
FvarInnerT rtn_value;
78-
Eigen::Matrix<FvarInnerT, -1, 1> grad;
79-
finite_diff_gradient_auto(serial_functor, to_vector(serialised_args),
80-
rtn_value, grad);
82+
std::vector<FvarInnerT> grad;
83+
finite_diff_gradient_auto(serial_functor, serialised_args, rtn_value, grad);
8184

8285
FvarInnerT rtn_grad = 0;
86+
auto grad_deserializer = to_deserializer(grad);
8387
// Use a fold-expression to aggregate tangents for input arguments
8488
(void)std::initializer_list<int>{(rtn_grad += internal::aggregate_tangent(
85-
to_deserializer(grad).read(args), args),
89+
grad_deserializer.read(args), args),
8690
0)...};
8791

8892
return FvarT(rtn_value, rtn_grad);

stan/math/prim/functor/finite_diff_gradient_auto.hpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,41 +46,42 @@ namespace math {
4646
* @param[out] fx function applied to argument
4747
* @param[out] grad_fx gradient of function at argument
4848
*/
49-
template <typename F, typename ScalarT>
49+
template <typename F, typename VectorT,
50+
typename ScalarT = return_type_t<VectorT>>
5051
void finite_diff_gradient_auto(const F& f,
51-
const Eigen::Matrix<ScalarT, -1, 1>& x,
52+
const VectorT& x,
5253
ScalarT& fx,
53-
Eigen::Matrix<ScalarT, -1, 1>& grad_fx) {
54-
Eigen::Matrix<ScalarT, -1, 1> x_temp(x);
54+
VectorT& grad_fx) {
55+
VectorT x_temp(x);
5556
fx = f(x);
5657
grad_fx.resize(x.size());
5758
for (int i = 0; i < x.size(); ++i) {
58-
double h = finite_diff_stepsize(value_of_rec(x(i)));
59+
double h = finite_diff_stepsize(value_of_rec(x[i]));
5960

6061
ScalarT delta_f = 0;
6162

62-
x_temp(i) = x(i) + 3 * h;
63+
x_temp[i] = x[i] + 3 * h;
6364
delta_f += f(x_temp);
6465

65-
x_temp(i) = x(i) + 2 * h;
66+
x_temp[i] = x[i] + 2 * h;
6667
delta_f -= 9 * f(x_temp);
6768

68-
x_temp(i) = x(i) + h;
69+
x_temp[i] = x[i] + h;
6970
delta_f += 45 * f(x_temp);
7071

71-
x_temp(i) = x(i) + -3 * h;
72+
x_temp[i] = x[i] + -3 * h;
7273
delta_f -= f(x_temp);
7374

74-
x_temp(i) = x(i) + -2 * h;
75+
x_temp[i] = x[i] + -2 * h;
7576
delta_f += 9 * f(x_temp);
7677

77-
x_temp(i) = x(i) - h;
78+
x_temp[i] = x[i] - h;
7879
delta_f -= 45 * f(x_temp);
7980

8081
delta_f /= 60 * h;
8182

82-
x_temp(i) = x(i);
83-
grad_fx(i) = delta_f;
83+
x_temp[i] = x[i];
84+
grad_fx[i] = delta_f;
8485
}
8586
}
8687

0 commit comments

Comments (0)