Skip to content

Commit fe521e0

Browse files
authored
Merge pull request #3089 from stan-dev/feature/vari-set-adj
adds constructor to vari for passing both initial values and adjoints
2 parents 87bb8a7 + 29b366e commit fe521e0

5 files changed

Lines changed: 55 additions & 7 deletions

File tree

stan/math/opencl/rev/vari.hpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ class vari_value<T, require_matrix_cl_t<T>> : public chainable_alloc,
227227
require_vt_same<T, S>* = nullptr>
228228
explicit vari_value(const S& x)
229229
: chainable_alloc(), vari_cl_base<T>(x, constant(0, x.rows(), x.cols())) {
230-
ChainableStack::instance_->var_stack_.push_back(this);
230+
ChainableStack::instance_->var_nochain_stack_.push_back(this);
231231
}
232232

233233
/**
@@ -259,6 +259,26 @@ class vari_value<T, require_matrix_cl_t<T>> : public chainable_alloc,
259259
}
260260
}
261261

262+
/**
263+
* Construct a matrix_cl variable implementation from
264+
* preconstructed values and adjoints.
265+
*
266+
* All constructed variables are added to the no chain stack. Variables
267+
* should be constructed before variables on which they depend
268+
* to ensure proper partial derivative propagation.
269+
* @tparam S A type that is convertible to `value_type`
270+
* @tparam K A type that is convertible to `value_type`
271+
* @param val Matrix of values
272+
* @param adj Matrix of adjoints
273+
*/
274+
template <typename S, typename K, require_convertible_t<T, S>* = nullptr,
275+
require_convertible_t<T, K>* = nullptr>
276+
explicit vari_value(S&& val, K&& adj)
277+
: chainable_alloc(),
278+
vari_cl_base<T>(std::forward<S>(val), std::forward<K>(adj)) {
279+
ChainableStack::instance_->var_nochain_stack_.push_back(this);
280+
}
281+
262282
/**
263283
* Set the adjoint value of this variable to 0. This is used to
264284
* reset adjoints before propagating derivatives again (for

stan/math/rev/core/callback_vari.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ struct callback_vari : public vari_value<T> {
1515
template <typename S,
1616
require_same_t<plain_type_t<T>, plain_type_t<S>>* = nullptr>
1717
explicit callback_vari(S&& value, F&& rev_functor)
18-
: vari_value<T>(std::move(value)),
18+
: vari_value<T>(std::move(value), true),
1919
rev_functor_(std::forward<F>(rev_functor)) {}
2020

2121
inline void chain() final { rev_functor_(*this); }

stan/math/rev/core/vari.hpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -678,11 +678,9 @@ class vari_value<T, require_all_t<is_plain_type<T>, is_eigen_dense_base<T>>>
678678
* Construct a dense Eigen variable implementation from a value. The
679679
* adjoint is initialized to zero.
680680
*
681-
* All constructed variables are added to the stack. Variables
681+
* All constructed variables are added to the no chain stack. Variables
682682
* should be constructed before variables on which they depend
683-
* to insure proper partial derivative propagation. During
684-
* derivative propagation, the chain() method of each variable
685-
* will be called in the reverse order of construction.
683+
* to ensure proper partial derivative propagation.
686684
*
687685
* @tparam S A dense Eigen type that is convertible to `value_type`
688686
* @param x Value of the constructed variable.
@@ -699,7 +697,7 @@ class vari_value<T, require_all_t<is_plain_type<T>, is_eigen_dense_base<T>>>
699697
? x.rows()
700698
: x.cols()) {
701699
adj_.setZero();
702-
ChainableStack::instance_->var_stack_.push_back(this);
700+
ChainableStack::instance_->var_nochain_stack_.push_back(this);
703701
}
704702

705703
/**
@@ -736,6 +734,24 @@ class vari_value<T, require_all_t<is_plain_type<T>, is_eigen_dense_base<T>>>
736734
}
737735
}
738736

737+
/**
738+
* Construct a dense Eigen variable implementation from
739+
* preconstructed values and adjoints.
740+
*
741+
* All constructed variables are added to the no chain stack. Variables
742+
* should be constructed before variables on which they depend
743+
* to ensure proper partial derivative propagation.
744+
* @tparam S A dense Eigen type that is convertible to `value_type`
745+
* @tparam K A dense Eigen type that is convertible to `value_type`
746+
* @param val Matrix of values
747+
* @param adj Matrix of adjoints
748+
*/
749+
template <typename S, typename K, require_assignable_t<T, S>* = nullptr,
750+
require_assignable_t<T, K>* = nullptr>
751+
explicit vari_value(const S& val, const K& adj) : val_(val), adj_(adj) {
752+
ChainableStack::instance_->var_nochain_stack_.push_back(this);
753+
}
754+
739755
protected:
740756
template <typename S, require_not_same_t<T, S>* = nullptr>
741757
explicit vari_value(const vari_value<S>* x) : val_(x->val_), adj_(x->adj_) {}

test/unit/math/opencl/rev/vari_test.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ TEST(AgradRev, matrix_cl_vari_block) {
2020
stan::math::from_matrix_cl(B.block(0, 1, 2, 2).val_));
2121
EXPECT_MATRIX_EQ(b.block(0, 1, 2, 2),
2222
stan::math::from_matrix_cl(B.block(0, 1, 2, 2).adj_));
23+
vari_value<stan::math::matrix_cl<double>> C(a_cl, a_cl);
24+
EXPECT_MATRIX_EQ(a.block(0, 1, 2, 2),
25+
stan::math::from_matrix_cl(C.block(0, 1, 2, 2).val_));
26+
EXPECT_MATRIX_EQ(a.block(0, 1, 2, 2),
27+
stan::math::from_matrix_cl(C.block(0, 1, 2, 2).adj_));
2328
}
2429

2530
#endif

test/unit/math/rev/core/vari_test.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@ TEST(AgradRevVari, arena_matrix_matrix_vari) {
7373
EXPECT_MATRIX_FLOAT_EQ((*C).val(), x);
7474
auto* D = new vari_value<Eigen::MatrixXd>(x_ref, true);
7575
EXPECT_MATRIX_FLOAT_EQ((*D).val(), x);
76+
auto* E = new vari_value<Eigen::MatrixXd>(x, (x.array() + 1.0).matrix());
77+
EXPECT_MATRIX_FLOAT_EQ((*E).val(), x);
78+
EXPECT_MATRIX_FLOAT_EQ((*E).adj(), (x.array() + 1.0).matrix());
79+
auto* F = new vari_value<Eigen::MatrixXd>(x, x);
80+
EXPECT_MATRIX_FLOAT_EQ((*F).val(), x);
81+
EXPECT_MATRIX_FLOAT_EQ((*F).adj(), x);
82+
EXPECT_EQ((*F).val().data(), (*F).adj().data());
7683
}
7784

7885
TEST(AgradRevVari, dense_vari_matrix_views) {

0 commit comments

Comments
 (0)