Skip to content

Commit 77e4812

Browse files
committed
Simplify fvar framework
1 parent 3435d1b commit 77e4812

2 files changed

Lines changed: 11 additions & 21 deletions

File tree

stan/math/fwd/functor/fvar_finite_diff.hpp

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#define STAN_MATH_FWD_FUNCTOR_FVAR_FINITE_DIFF_HPP
33

44
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/functor/apply.hpp>
65
#include <stan/math/prim/functor/finite_diff_gradient_auto.hpp>
76
#include <stan/math/prim/fun/value_of.hpp>
87
#include <stan/math/prim/fun/sum.hpp>
@@ -66,34 +65,25 @@ auto fvar_finite_diff(const F& func, const TArgs&... args) {
6665
using FvarT = return_type_t<TArgs...>;
6766
using FvarInnerT = typename FvarT::Scalar;
6867

69-
auto val_args = std::make_tuple(stan::math::value_of(args)...);
70-
71-
auto serialised_args = stan::math::apply(
72-
[&](auto&&... tuple_args) {
73-
return math::to_vector(
74-
stan::test::serialize<FvarInnerT>(tuple_args...));
75-
},
76-
val_args);
68+
auto serialised_args = test::serialize<FvarInnerT>(value_of(args)...);
7769

7870
// Create a 'wrapper' functor which will take the flattened column-vector
7971
// and transform it to individual arguments which are passed to the
8072
// user-provided functor
8173
auto serial_functor = [&](const auto& v) {
82-
auto ds = stan::test::to_deserializer(v);
83-
return stan::math::apply(
84-
[&](auto&&... tuple_args) { return func(ds.read(tuple_args)...); },
85-
val_args);
74+
return func(test::to_deserializer(v).read(args)...);
8675
};
8776

8877
FvarInnerT rtn_value;
8978
Eigen::Matrix<FvarInnerT, -1, 1> grad;
90-
finite_diff_gradient_auto(serial_functor, serialised_args, rtn_value, grad);
79+
finite_diff_gradient_auto(serial_functor, to_vector(serialised_args),
80+
rtn_value, grad);
9181

92-
auto ds_grad = stan::test::to_deserializer(grad);
9382
FvarInnerT rtn_grad = 0;
9483
// Use a fold-expression to aggregate tangents for input arguments
9584
(void)std::initializer_list<int>{(
96-
rtn_grad += internal::aggregate_tangent(ds_grad.read(args), args), 0)...};
85+
rtn_grad += internal::aggregate_tangent(
86+
test::to_deserializer(grad).read(args), args), 0)...};
9787

9888
return FvarT(rtn_value, rtn_grad);
9989
}
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
#include <test/unit/math/test_ad.hpp>
22

33
TEST(mixFunctor, integrate1D) {
4-
auto f = [&](const auto& x_input) {
4+
auto f = [&](const auto& x_input, const auto& lb, const auto& ub) {
55
auto func = [](const auto& x, const auto& xc, std::ostream* msgs,
66
const auto& theta) {
77
return stan::math::exp(theta * stan::math::cos(2 * 3.141593 * x)) + theta;
88
};
99
const double relative_tolerance = std::sqrt(stan::math::EPSILON);
1010
std::ostringstream* msgs = nullptr;
11-
return stan::math::integrate_1d_impl(func, 0, 1, relative_tolerance, msgs,
11+
return stan::math::integrate_1d_impl(func, lb, ub, relative_tolerance, msgs,
1212
x_input);
1313
};
14-
stan::test::expect_ad(f, 0.75);
15-
stan::test::expect_ad(f, 0.2);
16-
stan::test::expect_ad(f, stan::math::NOT_A_NUMBER);
14+
stan::test::expect_ad(f, 0.75, 0, 1);
15+
stan::test::expect_ad(f, 0.2, 0.2, 0.7);
16+
stan::test::expect_ad(f, stan::math::NOT_A_NUMBER, 0, 1);
1717
}

0 commit comments

Comments
 (0)