Skip to content

Commit 9251316

Browse files
authored
Merge pull request #2902 from stan-dev/expression-tests-tuples
Add necessary overloads for Tuple return types in expression tests
2 parents e45480c + b3afaed commit 9251316

1 file changed

Lines changed: 16 additions & 1 deletion

File tree

test/expressions/expression_test_helpers.hpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include <gtest/gtest.h>
22
#include <stan/math/prim.hpp>
3+
#include <stan/math/prim/meta.hpp>
4+
#include <stan/math/prim/functor/for_each.hpp>
35
#include <stan/math/rev.hpp>
46
#include <stan/math/fwd.hpp>
57
#include <vector>
@@ -18,7 +20,7 @@ struct counterOp {
1820
}
1921
};
2022

21-
template <typename T>
23+
template <typename T, stan::math::require_not_tuple_t<T>* = nullptr>
2224
auto recursive_sum(const T& a) {
2325
return math::sum(a);
2426
}
@@ -32,6 +34,13 @@ auto recursive_sum(const std::vector<T>& a) {
3234
return res;
3335
}
3436

37+
template <typename T, stan::math::require_tuple_t<T>* = nullptr>
38+
auto recursive_sum(const T& t1) {
39+
stan::value_type_t<decltype(std::get<0>(t1))> val = 0;
40+
stan::math::for_each([&val](auto&& elt1) { val += recursive_sum(elt1); }, t1);
41+
return val;
42+
}
43+
3544
template <typename T, require_integral_t<T>* = nullptr>
3645
T make_arg(double value = 0.4, int size = 1) {
3746
return 1;
@@ -160,6 +169,12 @@ void expect_eq(const std::vector<T>& a, const std::vector<T>& b,
160169
}
161170
}
162171

172+
template <typename T, stan::math::require_tuple_t<T>* = nullptr>
173+
void expect_eq(const T& t1, const T& t2, const char* msg) {
174+
stan::math::for_each(
175+
[&msg](auto&& elt1, auto&& elt2) { expect_eq(elt1, elt2, msg); }, t1, t2);
176+
}
177+
163178
template <typename T, require_not_st_var<T>* = nullptr>
164179
void expect_adj_eq(const T& a, const T& b, const char* msg = "expect_ad_eq") {}
165180

0 commit comments

Comments
 (0)