Skip to content

Commit 0d00dc0

Browse files
committed
Add necessary overloads for Tuple return types in expression tests
1 parent e45480c commit 0d00dc0

1 file changed

Lines changed: 17 additions & 1 deletion

File tree

test/expressions/expression_test_helpers.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include <gtest/gtest.h>
22
#include <stan/math/prim.hpp>
3+
#include <stan/math/prim/meta.hpp>
4+
#include <stan/math/prim/functor/for_each.hpp>
35
#include <stan/math/rev.hpp>
46
#include <stan/math/fwd.hpp>
57
#include <vector>
@@ -18,7 +20,8 @@ struct counterOp {
1820
}
1921
};
2022

21-
template <typename T>
23+
24+
template <typename T, stan::math::require_not_tuple_t<T>* = nullptr>
2225
auto recursive_sum(const T& a) {
2326
return math::sum(a);
2427
}
@@ -32,6 +35,13 @@ auto recursive_sum(const std::vector<T>& a) {
3235
return res;
3336
}
3437

38+
template <typename T, stan::math::require_tuple_t<T>* = nullptr>
39+
auto recursive_sum(const T& t1) {
40+
stan::value_type_t<decltype(std::get<0>(t1))> val = 0;
41+
stan::math::for_each([&val](auto&& elt1) { val += recursive_sum(elt1); }, t1);
42+
return val;
43+
}
44+
3545
template <typename T, require_integral_t<T>* = nullptr>
3646
T make_arg(double value = 0.4, int size = 1) {
3747
return 1;
@@ -160,6 +170,12 @@ void expect_eq(const std::vector<T>& a, const std::vector<T>& b,
160170
}
161171
}
162172

173+
template <typename T, stan::math::require_tuple_t<T>* = nullptr>
174+
void expect_eq(const T& t1, const T& t2, const char* msg) {
175+
stan::math::for_each(
176+
[&msg](auto&& elt1, auto&& elt2) { expect_eq(elt1, elt2, msg); }, t1, t2);
177+
}
178+
163179
template <typename T, require_not_st_var<T>* = nullptr>
164180
void expect_adj_eq(const T& a, const T& b, const char* msg = "expect_ad_eq") {}
165181

0 commit comments

Comments
 (0)