Skip to content

Commit 0426e8a

Browse files
committed
fix assignment for nullptr var_value<matrix> and for assigning expressions
1 parent e43fc08 commit 0426e8a

3 files changed

Lines changed: 107 additions & 7 deletions

File tree

stan/math/rev/core/arena_matrix.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,10 @@ class arena_matrix : public Eigen::Map<MatrixType> {
128128
Base::operator=(a);
129129
return *this;
130130
}
131+
template <typename T>
132+
void hard_copy(const T& x) {
133+
Base::operator=(x);
134+
}
131135
};
132136

133137
} // namespace math

stan/math/rev/core/var.hpp

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,8 @@ class var_value<T, internal::require_matrix_var_value<T>> {
390390
reverse_pass_callback(
391391
[this_vi = this->vi_, other_vi = other.vi_]() mutable {
392392
other_vi->adj_ += this_vi->adj_;
393+
// this adjoint has been accumulated into other_vi above, so clear it
394+
this_vi->adj_.setZero();
393395
});
394396
}
395397

@@ -1020,9 +1022,9 @@ class var_value<T, internal::require_matrix_var_value<T>> {
10201022
* @param other the value to assign
10211023
* @return this
10221024
*/
1023-
template <typename S, require_assignable_t<value_type, S>* = nullptr,
1024-
require_all_plain_type_t<T, S>* = nullptr,
1025-
require_not_same_t<plain_type_t<T>, plain_type_t<S>>* = nullptr>
1025+
template <typename S, typename T_ = T, require_assignable_t<value_type, S>* = nullptr,
1026+
require_all_plain_type_t<T_, S>* = nullptr,
1027+
require_not_same_t<plain_type_t<T_>, plain_type_t<S>>* = nullptr>
10261028
inline var_value<T>& operator=(const var_value<S>& other) {
10271029
static_assert(
10281030
EIGEN_PREDICATE_SAME_MATRIX_SIZE(T, S),
@@ -1032,16 +1034,63 @@ class var_value<T, internal::require_matrix_var_value<T>> {
10321034
}
10331035

10341036
/**
1035-
* Assignment of another var value, when either this or the other one does not
1037+
* Assignment of another var value, when `this` is a plain type and the other one does not
10361038
* contain a plain type.
1037-
* @tparam S type of the value in the `var_value` to assing
1039+
* @tparam S type of the value in the `var_value` to assign
1040+
* @param other the value to assign
1041+
* @return this
1042+
*/
1043+
template <typename S, typename T_ = T,
1044+
require_assignable_t<value_type, S>* = nullptr,
1045+
require_not_plain_type_t<S>* = nullptr,
1046+
require_plain_type_t<T_>* = nullptr>
1047+
inline var_value<T>& operator=(const var_value<S>& other) {
1048+
// If vi_ is nullptr then the var needs to be initialized by constructing it from `other`
1049+
if (!(this->vi_)) {
1050+
*this = var_value<T>(other);
1051+
return *this;
1052+
}
1053+
arena_t<plain_type_t<T>> prev_val(vi_->val_.rows(), vi_->val_.cols());
1054+
prev_val.hard_copy(vi_->val_);
1055+
vi_->val_.hard_copy(other.val());
1056+
// no need to change any adjoints - these are just zeros before the reverse
1057+
// pass
1058+
1059+
reverse_pass_callback(
1060+
[this_vi = this->vi_, other_vi = other.vi_, prev_val]() mutable {
1061+
this_vi->val_.hard_copy(prev_val);
1062+
1063+
// we have no way of detecting aliasing between this->vi_->adj_ and
1064+
// other.vi_->adj_, so we must copy adjoint before resetting to zero
1065+
1066+
// we can reuse prev_val instead of allocating a new matrix
1067+
prev_val.hard_copy(this_vi->adj_);
1068+
this_vi->adj_.setZero();
1069+
other_vi->adj_ += prev_val;
1070+
});
1071+
return *this;
1072+
}
1073+
/**
1074+
* Assignment of another var value, when `this` (and possibly also the other one) does not
1075+
* contain a plain type.
1076+
* @tparam S type of the value in the `var_value` to assign
10381077
* @param other the value to assign
10391078
* @return this
10401079
*/
10411080
template <typename S, typename T_ = T,
10421081
require_assignable_t<value_type, S>* = nullptr,
1043-
require_any_not_plain_type_t<T_, S>* = nullptr>
1082+
require_any_not_plain_type_t<T_, S>* = nullptr,
1083+
require_not_plain_type_t<T_>* = nullptr>
10441084
inline var_value<T>& operator=(const var_value<S>& other) {
1085+
// A nullptr vi_ cannot be assigned to in this overload; report it as an internal bug
1086+
if (!(this->vi_)) {
1087+
[]() STAN_COLD_PATH {
1088+
throw std::domain_error(
1089+
"var_value<matrix>::operator=(var_value<expression>):"
1090+
" Internal Bug! Please report this with an example"
1091+
" of your model to the Stan math github repository.");
1092+
}();
1093+
}
10451094
arena_t<plain_type_t<T>> prev_val = vi_->val_;
10461095
vi_->val_ = other.val();
10471096
// no need to change any adjoints - these are just zeros before the reverse
@@ -1055,13 +1104,14 @@ class var_value<T, internal::require_matrix_var_value<T>> {
10551104
// other.vi_->adj_, so we must copy adjoint before resetting to zero
10561105

10571106
// we can reuse prev_val instead of allocating a new matrix
1058-
prev_val = this_vi->adj_;
1107+
prev_val.hard_copy(this_vi->adj_);
10591108
this_vi->adj_.setZero();
10601109
other_vi->adj_ += prev_val;
10611110
});
10621111
return *this;
10631112
}
10641113

1114+
10651115
/**
10661116
* No-op to match with Eigen methods which call eval
10671117
*/

test/unit/math/rev/core/var_test.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,3 +910,49 @@ TEST_F(AgradRev, matrix_compile_time_conversions) {
910910
EXPECT_MATRIX_FLOAT_EQ(colvec.val(), rowvec.val());
911911
EXPECT_MATRIX_FLOAT_EQ(x11.val(), rowvec.val());
912912
}
913+
914+
TEST_F(AgradRev, assign_nan) {
915+
using stan::math::var_value;
916+
using var_vector = var_value<Eigen::Matrix<double,-1,1>>;
917+
using stan::math::var;
918+
Eigen::VectorXd x_val(10);
919+
for (int i = 0; i < 10; ++i) {
920+
x_val(i) = i + 0.1;
921+
}
922+
var_vector x(x_val);
923+
var_vector y = var_vector(Eigen::Matrix<double,-1,1>::Constant(10, std::numeric_limits<double>::quiet_NaN()));
924+
y = stan::math::head(x, 10);
925+
var sigma = 1.0;
926+
var lp = stan::math::normal_lpdf<false>(y, 0, sigma);
927+
lp.grad();
928+
Eigen::VectorXd x_ans_adj(10);
929+
for (int i = 0; i < 10; ++i) {
930+
x_ans_adj(i) = -(i + 0.1);
931+
}
932+
EXPECT_MATRIX_EQ(x.adj(), x_ans_adj);
933+
Eigen::VectorXd y_ans_adj = Eigen::VectorXd::Zero(10);
934+
EXPECT_MATRIX_EQ(y_ans_adj, y.adj());
935+
}
936+
937+
TEST_F(AgradRev, assign_nullptr_vari) {
938+
using stan::math::var_value;
939+
using var_vector = var_value<Eigen::Matrix<double,-1,1>>;
940+
using stan::math::var;
941+
Eigen::VectorXd x_val(10);
942+
for (int i = 0; i < 10; ++i) {
943+
x_val(i) = i + 0.1;
944+
}
945+
var_vector x(x_val);
946+
var_vector y;
947+
y = stan::math::head(x, 10);
948+
var sigma = 1.0;
949+
var lp = stan::math::normal_lpdf<false>(y, 0, sigma);
950+
lp.grad();
951+
Eigen::VectorXd x_ans_adj(10);
952+
for (int i = 0; i < 10; ++i) {
953+
x_ans_adj(i) = -(i + 0.1);
954+
}
955+
EXPECT_MATRIX_EQ(x.adj(), x_ans_adj);
956+
Eigen::VectorXd y_ans_adj = Eigen::VectorXd::Zero(10);
957+
EXPECT_MATRIX_EQ(y_ans_adj, y.adj());
958+
}

0 commit comments

Comments
 (0)