|
2 | 2 | #define STAN_MATH_FWD_FUNCTOR_FVAR_FINITE_DIFF_HPP |
3 | 3 |
|
4 | 4 | #include <stan/math/prim/meta.hpp> |
5 | | -#include <stan/math/prim/functor/apply.hpp> |
6 | 5 | #include <stan/math/prim/functor/finite_diff_gradient_auto.hpp> |
7 | 6 | #include <stan/math/prim/fun/value_of.hpp> |
8 | 7 | #include <stan/math/prim/fun/sum.hpp> |
@@ -66,34 +65,25 @@ auto fvar_finite_diff(const F& func, const TArgs&... args) { |
66 | 65 | using FvarT = return_type_t<TArgs...>; |
67 | 66 | using FvarInnerT = typename FvarT::Scalar; |
68 | 67 |
|
69 | | - auto val_args = std::make_tuple(stan::math::value_of(args)...); |
70 | | - |
71 | | - auto serialised_args = stan::math::apply( |
72 | | - [&](auto&&... tuple_args) { |
73 | | - return math::to_vector( |
74 | | - stan::test::serialize<FvarInnerT>(tuple_args...)); |
75 | | - }, |
76 | | - val_args); |
| 68 | + auto serialised_args = test::serialize<FvarInnerT>(value_of(args)...); |
77 | 69 |
|
78 | 70 | // Create a 'wrapper' functor which will take the flattened column-vector |
79 | 71 | // and transform it to individual arguments which are passed to the |
80 | 72 | // user-provided functor |
81 | 73 | auto serial_functor = [&](const auto& v) { |
82 | | - auto ds = stan::test::to_deserializer(v); |
83 | | - return stan::math::apply( |
84 | | - [&](auto&&... tuple_args) { return func(ds.read(tuple_args)...); }, |
85 | | - val_args); |
| 74 | + return func(test::to_deserializer(v).read(args)...); |
86 | 75 | }; |
87 | 76 |
|
88 | 77 | FvarInnerT rtn_value; |
89 | 78 | Eigen::Matrix<FvarInnerT, -1, 1> grad; |
90 | | - finite_diff_gradient_auto(serial_functor, serialised_args, rtn_value, grad); |
| 79 | + finite_diff_gradient_auto(serial_functor, to_vector(serialised_args), |
| 80 | + rtn_value, grad); |
91 | 81 |
|
92 | | - auto ds_grad = stan::test::to_deserializer(grad); |
93 | 82 | FvarInnerT rtn_grad = 0; |
94 | 83 | // Use a fold-expression to aggregate tangents for input arguments |
95 | 84 | (void)std::initializer_list<int>{( |
96 | | - rtn_grad += internal::aggregate_tangent(ds_grad.read(args), args), 0)...}; |
| 85 | + rtn_grad += internal::aggregate_tangent( |
| 86 | + test::to_deserializer(grad).read(args), args), 0)...}; |
97 | 87 |
|
98 | 88 | return FvarT(rtn_value, rtn_grad); |
99 | 89 | } |
|
0 commit comments