@@ -17,15 +17,15 @@ namespace internal {
1717 *
1818 * Overload for when the input is not an fvar<T> and no tangents are needed.
1919 *
20- * @tparam FuncTangentT Type of tangent calculated by finite-differences
21- * @tparam InputArgT Type of the function input argument
20+ * @tparam FuncTangent Type of tangent calculated by finite-differences
21+ * @tparam InputArg Type of the function input argument
2222 * @param tangent Calculated tangent
2323 * @param arg Input argument
2424 */
25- template <typename FuncTangentT , typename InputArgT ,
26- require_not_st_fvar<InputArgT >* = nullptr >
27- inline constexpr double aggregate_tangent (const FuncTangentT & tangent,
28- const InputArgT & arg) {
25+ template <typename FuncTangent , typename InputArg ,
26+ require_not_st_fvar<InputArg >* = nullptr >
27+ inline constexpr double aggregate_tangent (const FuncTangent & tangent,
28+ const InputArg & arg) {
2929 return 0 ;
3030}
3131
@@ -36,14 +36,14 @@ inline constexpr double aggregate_tangent(const FuncTangentT& tangent,
3636 * Overload for when the input is an fvar<T> and its tangent needs to be
3737 * aggregated.
3838 *
39- * @tparam FuncTangentT Type of tangent calculated by finite-differences
40- * @tparam InputArgT Type of the function input argument
39+ * @tparam FuncTangent Type of tangent calculated by finite-differences
40+ * @tparam InputArg Type of the function input argument
4141 * @param tangent Calculated tangent
4242 * @param arg Input argument
4343 */
44- template <typename FuncTangentT , typename InputArgT ,
45- require_st_fvar<InputArgT >* = nullptr >
46- auto aggregate_tangent (const FuncTangentT & tangent, const InputArgT & arg) {
44+ template <typename FuncTangent , typename InputArg ,
45+ require_st_fvar<InputArg >* = nullptr >
46+ auto aggregate_tangent (const FuncTangent & tangent, const InputArg & arg) {
4747 return sum (apply_scalar_binary (
4848 tangent, arg, [](const auto & x, const auto & y) { return x * y.d_ ; }));
4949}
@@ -66,23 +66,27 @@ auto fvar_finite_diff(const F& func, const TArgs&... args) {
6666 using FvarT = return_type_t <TArgs...>;
6767 using FvarInnerT = typename FvarT::Scalar;
6868
69- auto serialised_args = serialize<FvarInnerT>(value_of (args)...);
69+ std::vector<FvarInnerT> serialised_args
70+ = serialize<FvarInnerT>(value_of (args)...);
7071
7172 // Create a 'wrapper' functor which will take the flattened column-vector
7273 // and transform it to individual arguments which are passed to the
7374 // user-provided functor
7475 auto serial_functor
75- = [&](const auto & v) { return func (to_deserializer (v).read (args)...); };
76+ = [&](const auto & v) {
77+ auto v_deserializer = to_deserializer (v);
78+ return func (v_deserializer.read (args)...);
79+ };
7680
7781 FvarInnerT rtn_value;
78- Eigen::Matrix<FvarInnerT, -1 , 1 > grad;
79- finite_diff_gradient_auto (serial_functor, to_vector (serialised_args),
80- rtn_value, grad);
82+ std::vector<FvarInnerT> grad;
83+ finite_diff_gradient_auto (serial_functor, serialised_args, rtn_value, grad);
8184
8285 FvarInnerT rtn_grad = 0 ;
86+ auto grad_deserializer = to_deserializer (grad);
8387 // Use a fold-expression to aggregate tangents for input arguments
8488 (void )std::initializer_list<int >{(rtn_grad += internal::aggregate_tangent (
85- to_deserializer (grad) .read (args), args),
89+ grad_deserializer .read (args), args),
8690 0 )...};
8791
8892 return FvarT (rtn_value, rtn_grad);
0 commit comments