Skip to content

Commit ce7af32

Browse files
committed
allow as_array_or_scalar to return a holder of a map for std vector types
1 parent cef520c commit ce7af32

2 files changed

Lines changed: 12 additions & 12 deletions

File tree

stan/math/prim/fun/as_array_or_scalar.hpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,8 @@ template <typename T, require_std_vector_t<T>* = nullptr,
6969
inline auto as_array_or_scalar(T&& v) {
7070
using arr_t = Eigen::Array<value_type_t<T>, Eigen::Dynamic, 1>;
7171
using T_map = Eigen::Map<const arr_t>;
72-
if constexpr (std::is_rvalue_reference_v<T&&>) {
73-
return make_holder([](auto&& x) { return T_map(x.data(), x.size()); },
74-
std::forward<T>(v));
75-
} else {
76-
return arr_t(T_map(v.data(), v.size()));
77-
}
72+
return make_holder([](auto&& x) { return T_map(x.data(), x.size()).matrix().array(); },
73+
std::forward<T>(v));
7874
}
7975

8076
/**

stan/math/prim/fun/as_column_vector_or_scalar.hpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@ inline auto as_column_vector_or_scalar(T&& a) {
6464
/**
6565
* Converts `std::vector` to a column vector.
6666
*
67+
* @note The math library's reverse mode assumes that `Eigen::Map`
68+
* types are allocated and owned elsewhere so we cannot just return
69+
* back a map here else the reverse mode library
70+
* may try to access into a dangling pointer. Instead we wrap
71+
* the `Eigen::Map` in a `Holder` to trick the reverse mode library
72+
* into not thinking this is a map. The `.array().matrix()` inside the
73+
* holder is so that the holder thinks it is returning an expression.
6774
* @tparam T `std::vector` type.
6875
* @param a Specified vector.
6976
* @return input converted to a column vector.
@@ -74,13 +81,10 @@ inline auto as_column_vector_or_scalar(T&& a) {
7481
using optionally_const_vector
7582
= std::conditional_t<std::is_const<std::remove_reference_t<T>>::value,
7683
const plain_vector, plain_vector>;
84+
7785
using T_map = Eigen::Map<optionally_const_vector>;
78-
if constexpr (std::is_rvalue_reference_v<T&&>) {
79-
return make_holder([](auto&& x) { return T_map(x.data(), x.size()); },
80-
std::forward<T>(a));
81-
} else {
82-
return plain_vector(T_map(a.data(), a.size()));
83-
}
86+
return make_holder([](auto&& x) { return T_map(x.data(), x.size()).array().matrix(); },
87+
std::forward<T>(a));
8488
}
8589

8690
} // namespace math

0 commit comments

Comments
 (0)