Skip to content

Commit 85b5472

Browse files
nsicchaclaude
andcommitted
Add tuple overloads for deep_copy_vars, save_varis, accumulate_adjoints
These three functions handle var, vector<var>, Eigen<var>, and arithmetic types but not tuples. With STAN_THREADS=true, reduce_sum passes tuple arguments through these functions, causing compilation failures. Add tuple overloads that unpack the tuple via stan::math::apply and recursively process each element. Fixes #3041 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 23dafa2 commit 85b5472

3 files changed

Lines changed: 57 additions & 0 deletions

File tree

stan/math/rev/core/accumulate_adjoints.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
#define STAN_MATH_REV_CORE_ACCUMULATE_ADJOINTS_HPP
33

44
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/functor/apply.hpp>
56
#include <stan/math/rev/meta.hpp>
67
#include <stan/math/rev/core/var.hpp>
78

9+
#include <tuple>
810
#include <utility>
911
#include <vector>
1012

@@ -35,6 +37,9 @@ inline double* accumulate_adjoints(double* dest, Arith&& x, Pargs&&... args);
3537

3638
inline double* accumulate_adjoints(double* dest);
3739

40+
template <typename Tuple, require_tuple_t<Tuple>* = nullptr, typename... Pargs>
41+
inline double* accumulate_adjoints(double* dest, Tuple&& x, Pargs&&... args);
42+
3843
/**
3944
* Accumulate adjoints from x into storage pointed to by dest,
4045
* increment the adjoint storage pointer,
@@ -147,6 +152,20 @@ inline double* accumulate_adjoints(double* dest, Arith&& x, Pargs&&... args) {
147152
*/
148153
inline double* accumulate_adjoints(double* dest) { return dest; }
149154

155+
/**
156+
* Unpack a tuple and accumulate adjoints from each element.
157+
*/
158+
template <typename Tuple, require_tuple_t<Tuple>* = nullptr, typename... Pargs>
159+
inline double* accumulate_adjoints(double* dest, Tuple&& x, Pargs&&... args) {
160+
dest = stan::math::apply(
161+
[dest](auto&&... inner_args) {
162+
return accumulate_adjoints(
163+
dest, std::forward<decltype(inner_args)>(inner_args)...);
164+
},
165+
std::forward<Tuple>(x));
166+
return accumulate_adjoints(dest, std::forward<Pargs>(args)...);
167+
}
168+
150169
} // namespace math
151170
} // namespace stan
152171

stan/math/rev/core/deep_copy_vars.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
#define STAN_MATH_REV_CORE_DEEP_COPY_VARS_HPP
33

44
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/functor/apply.hpp>
56
#include <stan/math/rev/meta.hpp>
67
#include <stan/math/rev/core/var.hpp>
78

9+
#include <tuple>
810
#include <utility>
911
#include <vector>
1012

@@ -81,6 +83,24 @@ inline auto deep_copy_vars(EigT&& arg) {
8183
.eval();
8284
}
8385

86+
/**
87+
* Copy the vars in a tuple but reallocate new varis for them.
88+
* Non-var elements are forwarded unchanged.
89+
*
90+
* @tparam Tuple A std::tuple type
91+
* @param arg A tuple potentially containing vars
92+
* @return A new tuple with deep-copied vars
93+
*/
94+
template <typename Tuple, require_tuple_t<Tuple>* = nullptr>
95+
inline auto deep_copy_vars(Tuple&& arg) {
96+
return stan::math::apply(
97+
[](auto&&... args) {
98+
return std::make_tuple(deep_copy_vars(
99+
std::forward<decltype(args)>(args))...);
100+
},
101+
std::forward<Tuple>(arg));
102+
}
103+
84104
} // namespace math
85105
} // namespace stan
86106

stan/math/rev/core/save_varis.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33

44
#include <stan/math/prim/fun/Eigen.hpp>
55
#include <stan/math/prim/meta.hpp>
6+
#include <stan/math/prim/functor/apply.hpp>
67
#include <stan/math/rev/meta.hpp>
78
#include <stan/math/rev/core/var.hpp>
89

10+
#include <tuple>
911
#include <utility>
1012
#include <vector>
1113

@@ -35,6 +37,9 @@ inline vari** save_varis(vari** dest, Arith&& x, Pargs&&... args);
3537

3638
inline vari** save_varis(vari** dest);
3739

40+
template <typename Tuple, require_tuple_t<Tuple>* = nullptr, typename... Pargs>
41+
inline vari** save_varis(vari** dest, Tuple&& x, Pargs&&... args);
42+
3843
/**
3944
* Save the vari pointer in x into the memory pointed to by dest,
4045
* increment the dest storage pointer,
@@ -143,6 +148,19 @@ inline vari** save_varis(vari** dest, Arith&& x, Pargs&&... args) {
143148
*/
144149
inline vari** save_varis(vari** dest) { return dest; }
145150

151+
/**
152+
* Unpack a tuple and save the varis of each element.
153+
*/
154+
template <typename Tuple, require_tuple_t<Tuple>* = nullptr, typename... Pargs>
155+
inline vari** save_varis(vari** dest, Tuple&& x, Pargs&&... args) {
156+
dest = stan::math::apply(
157+
[dest](auto&&... inner_args) {
158+
return save_varis(dest, std::forward<decltype(inner_args)>(inner_args)...);
159+
},
160+
std::forward<Tuple>(x));
161+
return save_varis(dest, std::forward<Pargs>(args)...);
162+
}
163+
146164
} // namespace math
147165
} // namespace stan
148166

0 commit comments

Comments
 (0)