Skip to content

Commit b713e58

Browse files
committed
Update vectorisation
1 parent a67a993 commit b713e58

1 file changed

Lines changed: 7 additions & 6 deletions

File tree

stan/math/rev/prob/std_normal_log_qf.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
#include <stan/math/rev/meta.hpp>
55
#include <stan/math/rev/core.hpp>
66
#include <stan/math/prim/prob/std_normal_log_qf.hpp>
7-
#include <stan/math/prim/functor/apply_scalar_ternary.hpp>
7+
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
8+
#include <stan/math/prim/fun/elt_multiply.hpp>
89
#include <cmath>
910

1011
namespace stan {
@@ -20,11 +21,11 @@ template <typename T, require_stan_scalar_or_eigen_t<T>* = nullptr>
2021
inline auto std_normal_log_qf(const var_value<T>& log_p) {
2122
const auto& arena_rtn = to_arena(std_normal_log_qf(log_p.val()));
2223
return make_callback_var(arena_rtn, [log_p, arena_rtn](auto& vi) mutable {
23-
log_p.adj() += apply_scalar_ternary(
24-
[](const auto& adj, const auto& logp_val, const auto& rtn_val) {
25-
return adj * exp(logp_val - std_normal_lpdf(rtn_val));
26-
},
27-
vi.adj(), log_p.val(), arena_rtn.val());
24+
auto deriv = apply_scalar_binary(log_p.val(), arena_rtn,
25+
[](const auto& logp_val, const auto& rtn_val) {
26+
return exp(logp_val - std_normal_lpdf(rtn_val));
27+
});
28+
log_p.adj() += elt_multiply(vi.adj(), deriv);
2829
});
2930
}
3031

0 commit comments

Comments
 (0)