Skip to content

Commit ac20682

Browse files
nsicchaclaude
andcommitted
Trim includes, add docs and tests for tuple overloads
Remove unnecessary <tuple> includes (already available via prim/meta.hpp). Add doxygen comments matching existing style. Add unit tests for tuple overloads of deep_copy_vars, save_varis, and accumulate_adjoints covering tuple<var, int>, tuple<var, double>, and tuple<var, var>. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 46eae9c commit ac20682

6 files changed

Lines changed: 134 additions & 12 deletions

File tree

stan/math/rev/core/accumulate_adjoints.hpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#include <stan/math/rev/meta.hpp>
77
#include <stan/math/rev/core/var.hpp>
88

9-
#include <tuple>
109
#include <utility>
1110
#include <vector>
1211

@@ -153,9 +152,17 @@ inline double* accumulate_adjoints(double* dest, Arith&& x, Pargs&&... args) {
153152
inline double* accumulate_adjoints(double* dest) { return dest; }
154153

155154
/**
156-
* Unpack a tuple and accumulate adjoints from each element.
155+
* Accumulate adjoints from a tuple into storage pointed to by dest
156+
* by unpacking the tuple and recursively processing each element.
157+
*
158+
* @tparam Tuple A std::tuple type
159+
* @tparam Pargs Types of remaining arguments
160+
* @param dest Pointer to where adjoints are to be accumulated
161+
* @param x A tuple potentially containing vars
162+
* @param args Further args to accumulate over
163+
* @return Final position of adjoint storage pointer
157164
*/
158-
template <typename Tuple, require_tuple_t<Tuple>* = nullptr, typename... Pargs>
165+
template <typename Tuple, require_tuple_t<Tuple>*, typename... Pargs>
159166
inline double* accumulate_adjoints(double* dest, Tuple&& x, Pargs&&... args) {
160167
dest = stan::math::apply(
161168
[dest](auto&&... inner_args) {

stan/math/rev/core/deep_copy_vars.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#include <stan/math/rev/meta.hpp>
77
#include <stan/math/rev/core/var.hpp>
88

9-
#include <tuple>
109
#include <utility>
1110
#include <vector>
1211

@@ -84,8 +83,8 @@ inline auto deep_copy_vars(EigT&& arg) {
8483
}
8584

8685
/**
87-
* Copy the vars in a tuple but reallocate new varis for them.
88-
* Non-var elements are forwarded unchanged.
86+
* Deep copy vars in a tuple, reallocating new varis for var elements
87+
* and forwarding non-var elements unchanged.
8988
*
9089
* @tparam Tuple A std::tuple type
9190
* @param arg A tuple potentially containing vars
@@ -95,8 +94,8 @@ template <typename Tuple, require_tuple_t<Tuple>* = nullptr>
9594
inline auto deep_copy_vars(Tuple&& arg) {
9695
return stan::math::apply(
9796
[](auto&&... args) {
98-
return std::make_tuple(deep_copy_vars(
99-
std::forward<decltype(args)>(args))...);
97+
return std::make_tuple(
98+
deep_copy_vars(std::forward<decltype(args)>(args))...);
10099
},
101100
std::forward<Tuple>(arg));
102101
}

stan/math/rev/core/save_varis.hpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#include <stan/math/rev/meta.hpp>
88
#include <stan/math/rev/core/var.hpp>
99

10-
#include <tuple>
1110
#include <utility>
1211
#include <vector>
1312

@@ -149,13 +148,22 @@ inline vari** save_varis(vari** dest, Arith&& x, Pargs&&... args) {
149148
inline vari** save_varis(vari** dest) { return dest; }
150149

151150
/**
152-
* Unpack a tuple and save the varis of each element.
151+
* Save the vari pointers in a tuple into the memory pointed to by dest
152+
* by unpacking the tuple and recursively processing each element.
153+
*
154+
* @tparam Tuple A std::tuple type
155+
* @tparam Pargs Types of remaining arguments
156+
* @param[in, out] dest Pointer to where vari pointers are saved
157+
* @param[in] x A tuple potentially containing vars
158+
* @param[in] args Additional arguments to have their varis saved
159+
* @return Final position of dest pointer
153160
*/
154-
template <typename Tuple, require_tuple_t<Tuple>* = nullptr, typename... Pargs>
161+
template <typename Tuple, require_tuple_t<Tuple>*, typename... Pargs>
155162
inline vari** save_varis(vari** dest, Tuple&& x, Pargs&&... args) {
156163
dest = stan::math::apply(
157164
[dest](auto&&... inner_args) {
158-
return save_varis(dest, std::forward<decltype(inner_args)>(inner_args)...);
165+
return save_varis(
166+
dest, std::forward<decltype(inner_args)>(inner_args)...);
159167
},
160168
std::forward<Tuple>(x));
161169
return save_varis(dest, std::forward<Pargs>(args)...);

test/unit/math/rev/core/accumulate_adjoints_test.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,44 @@ TEST_F(AgradRev, Rev_accumulate_adjoints_std_vector_eigen_matrix_var_arg) {
381381
stan::math::recover_memory();
382382
}
383383

384+
TEST_F(AgradRev, Rev_accumulate_adjoints_tuple_var_int_arg) {
385+
using stan::math::var;
386+
using stan::math::vari;
387+
var a(5.0);
388+
a.vi_->adj_ = 3.0;
389+
int b = 7;
390+
auto arg = std::make_tuple(a, b);
391+
392+
Eigen::VectorXd storage = Eigen::VectorXd::Zero(1000);
393+
double* ptr = stan::math::accumulate_adjoints(storage.data(), arg);
394+
395+
EXPECT_FLOAT_EQ(storage(0), 3.0);
396+
for (int i = 1; i < storage.size(); ++i)
397+
EXPECT_FLOAT_EQ(storage(i), 0.0);
398+
EXPECT_EQ(ptr, storage.data() + 1);
399+
stan::math::recover_memory();
400+
}
401+
402+
TEST_F(AgradRev, Rev_accumulate_adjoints_tuple_var_var_arg) {
403+
using stan::math::var;
404+
using stan::math::vari;
405+
var a(5.0);
406+
a.vi_->adj_ = 3.0;
407+
var b(7.0);
408+
b.vi_->adj_ = 4.0;
409+
auto arg = std::make_tuple(a, b);
410+
411+
Eigen::VectorXd storage = Eigen::VectorXd::Zero(1000);
412+
double* ptr = stan::math::accumulate_adjoints(storage.data(), arg);
413+
414+
EXPECT_FLOAT_EQ(storage(0), 3.0);
415+
EXPECT_FLOAT_EQ(storage(1), 4.0);
416+
for (int i = 2; i < storage.size(); ++i)
417+
EXPECT_FLOAT_EQ(storage(i), 0.0);
418+
EXPECT_EQ(ptr, storage.data() + 2);
419+
stan::math::recover_memory();
420+
}
421+
384422
TEST_F(AgradRev, Rev_accumulate_adjoints_sum) {
385423
using stan::math::var;
386424
using stan::math::vari;

test/unit/math/rev/core/deep_copy_vars_test.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,43 @@ TEST_F(AgradRev, Rev_deep_copy_vars_std_vector_eigen_row_vector_var_arg) {
283283
}
284284
}
285285

286+
TEST_F(AgradRev, Rev_deep_copy_vars_tuple_var_int_arg) {
287+
var a(3.0);
288+
int b = 5;
289+
auto arg = std::make_tuple(a, b);
290+
291+
auto out = stan::math::deep_copy_vars(arg);
292+
293+
EXPECT_EQ(std::get<0>(out).val(), a.val());
294+
EXPECT_NE(std::get<0>(out).vi_, a.vi_);
295+
EXPECT_EQ(std::get<1>(out), b);
296+
}
297+
298+
TEST_F(AgradRev, Rev_deep_copy_vars_tuple_var_double_arg) {
299+
var a(3.0);
300+
double b = 5.0;
301+
auto arg = std::make_tuple(a, b);
302+
303+
auto out = stan::math::deep_copy_vars(arg);
304+
305+
EXPECT_EQ(std::get<0>(out).val(), a.val());
306+
EXPECT_NE(std::get<0>(out).vi_, a.vi_);
307+
EXPECT_EQ(std::get<1>(out), b);
308+
}
309+
310+
TEST_F(AgradRev, Rev_deep_copy_vars_tuple_var_var_arg) {
311+
var a(3.0);
312+
var b(7.0);
313+
auto arg = std::make_tuple(a, b);
314+
315+
auto out = stan::math::deep_copy_vars(arg);
316+
317+
EXPECT_EQ(std::get<0>(out).val(), a.val());
318+
EXPECT_NE(std::get<0>(out).vi_, a.vi_);
319+
EXPECT_EQ(std::get<1>(out).val(), b.val());
320+
EXPECT_NE(std::get<1>(out).vi_, b.vi_);
321+
}
322+
286323
TEST_F(AgradRev, Rev_deep_copy_vars_std_vector_eigen_matrix_var_arg) {
287324
Eigen::Matrix<var, Eigen::Dynamic, Eigen::Dynamic> arg_(5, 3);
288325
std::vector<Eigen::Matrix<var, Eigen::Dynamic, Eigen::Dynamic>> arg(2, arg_);

test/unit/math/rev/core/save_varis_test.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,39 @@ TEST_F(AgradRev, Rev_save_varis_std_vector_eigen_matrix_var_arg) {
332332
EXPECT_EQ(ptr, storage.data() + num_vars);
333333
}
334334

335+
TEST_F(AgradRev, Rev_save_varis_tuple_var_int_arg) {
336+
var a(5.0);
337+
int b = 3;
338+
auto arg = std::make_tuple(a, b);
339+
340+
std::vector<vari*> storage(1000, nullptr);
341+
vari** ptr = stan::math::save_varis(storage.data(), arg);
342+
343+
size_t num_vars = stan::math::count_vars(arg);
344+
EXPECT_EQ(num_vars, 1);
345+
EXPECT_EQ(storage[0], a.vi_);
346+
for (int i = num_vars; i < storage.size(); ++i)
347+
EXPECT_EQ(storage[i], nullptr);
348+
EXPECT_EQ(ptr, storage.data() + num_vars);
349+
}
350+
351+
TEST_F(AgradRev, Rev_save_varis_tuple_var_var_arg) {
352+
var a(5.0);
353+
var b(7.0);
354+
auto arg = std::make_tuple(a, b);
355+
356+
std::vector<vari*> storage(1000, nullptr);
357+
vari** ptr = stan::math::save_varis(storage.data(), arg);
358+
359+
size_t num_vars = stan::math::count_vars(arg);
360+
EXPECT_EQ(num_vars, 2);
361+
EXPECT_EQ(storage[0], a.vi_);
362+
EXPECT_EQ(storage[1], b.vi_);
363+
for (int i = num_vars; i < storage.size(); ++i)
364+
EXPECT_EQ(storage[i], nullptr);
365+
EXPECT_EQ(ptr, storage.data() + num_vars);
366+
}
367+
335368
TEST_F(AgradRev, Rev_save_varis_sum) {
336369
int arg1 = 1;
337370
double arg2 = 1.0;

0 commit comments

Comments
 (0)