@@ -390,6 +390,8 @@ class var_value<T, internal::require_matrix_var_value<T>> {
390390 reverse_pass_callback (
391391 [this_vi = this ->vi_ , other_vi = other.vi_ ]() mutable {
392392 other_vi->adj_ += this_vi->adj_ ;
393+ //
394+ this_vi->adj_ .setZero ();
393395 });
394396 }
395397
@@ -1020,9 +1022,9 @@ class var_value<T, internal::require_matrix_var_value<T>> {
10201022 * @param other the value to assign
10211023 * @return this
10221024 */
1023- template <typename S, require_assignable_t <value_type, S>* = nullptr ,
1024- require_all_plain_type_t <T , S>* = nullptr ,
1025- require_not_same_t <plain_type_t <T >, plain_type_t <S>>* = nullptr >
1025+ template <typename S, typename T_ = T, require_assignable_t <value_type, S>* = nullptr ,
1026+ require_all_plain_type_t <T_ , S>* = nullptr ,
1027+ require_not_same_t <plain_type_t <T_ >, plain_type_t <S>>* = nullptr >
10261028 inline var_value<T>& operator =(const var_value<S>& other) {
10271029 static_assert (
10281030 EIGEN_PREDICATE_SAME_MATRIX_SIZE (T, S),
@@ -1032,16 +1034,63 @@ class var_value<T, internal::require_matrix_var_value<T>> {
10321034 }
10331035
10341036 /* *
1035- * Assignment of another var value, when either this or the other one does not
1037+ * Assignment of another var value, when the `this` does not
10361038 * contain a plain type.
1037- * @tparam S type of the value in the `var_value` to assing
1039+ * @tparam S type of the value in the `var_value` to assign
1040+ * @param other the value to assign
1041+ * @return this
1042+ */
1043+ template <typename S, typename T_ = T,
1044+ require_assignable_t <value_type, S>* = nullptr ,
1045+ require_not_plain_type_t <S>* = nullptr ,
1046+ require_plain_type_t <T_>* = nullptr >
1047+ inline var_value<T>& operator =(const var_value<S>& other) {
1048+ // If vi_ is nullptr then the var needs initialized via copy constructor
1049+ if (!(this ->vi_ )) {
1050+ *this = var_value<T>(other);
1051+ return *this ;
1052+ }
1053+ arena_t <plain_type_t <T>> prev_val (vi_->val_ .rows (), vi_->val_ .cols ());
1054+ prev_val.hard_copy (vi_->val_ );
1055+ vi_->val_ .hard_copy (other.val ());
1056+ // no need to change any adjoints - these are just zeros before the reverse
1057+ // pass
1058+
1059+ reverse_pass_callback (
1060+ [this_vi = this ->vi_ , other_vi = other.vi_ , prev_val]() mutable {
1061+ this_vi->val_ .hard_copy (prev_val);
1062+
1063+ // we have no way of detecting aliasing between this->vi_->adj_ and
1064+ // other.vi_->adj_, so we must copy adjoint before reseting to zero
1065+
1066+ // we can reuse prev_val instead of allocating a new matrix
1067+ prev_val.hard_copy (this_vi->adj_ );
1068+ this_vi->adj_ .setZero ();
1069+ other_vi->adj_ += prev_val;
1070+ });
1071+ return *this ;
1072+ }
1073+ /* *
1074+ * Assignment of another var value, when either both `this` or other does not
1075+ * contain a plain type.
1076+ * @tparam S type of the value in the `var_value` to assign
10381077 * @param other the value to assign
10391078 * @return this
10401079 */
10411080 template <typename S, typename T_ = T,
10421081 require_assignable_t <value_type, S>* = nullptr ,
1043- require_any_not_plain_type_t <T_, S>* = nullptr >
1082+ require_any_not_plain_type_t <T_, S>* = nullptr ,
1083+ require_not_plain_type_t <T_>* = nullptr >
10441084 inline var_value<T>& operator =(const var_value<S>& other) {
1085+ // If vi_ is nullptr then the var needs initialized via copy constructor
1086+ if (!(this ->vi_ )) {
1087+ []() STAN_COLD_PATH {
1088+ throw std::domain_error (
1089+ " var_value<matrix>::operator=(var_value<expression>):"
1090+ " Internal Bug! Please report this with an example"
1091+ " of your model to the Stan math github repository." );
1092+ }();
1093+ }
10451094 arena_t <plain_type_t <T>> prev_val = vi_->val_ ;
10461095 vi_->val_ = other.val ();
10471096 // no need to change any adjoints - these are just zeros before the reverse
@@ -1055,13 +1104,14 @@ class var_value<T, internal::require_matrix_var_value<T>> {
10551104 // other.vi_->adj_, so we must copy adjoint before reseting to zero
10561105
10571106 // we can reuse prev_val instead of allocating a new matrix
1058- prev_val = this_vi->adj_ ;
1107+ prev_val. hard_copy ( this_vi->adj_ ) ;
10591108 this_vi->adj_ .setZero ();
10601109 other_vi->adj_ += prev_val;
10611110 });
10621111 return *this ;
10631112 }
10641113
1114+
10651115 /* *
10661116 * No-op to match with Eigen methods which call eval
10671117 */
0 commit comments