33
44#include < stan/math/rev/meta.hpp>
55#include < stan/math/rev/core.hpp>
6- #include < stan/math/prim/fun/constants.hpp>
76#include < stan/math/prim/fun/inv_Phi.hpp>
7+ #include < stan/math/prim/prob/std_normal_lpdf.hpp>
8+ #include < stan/math/prim/functor/apply_scalar_binary.hpp>
89#include < cmath>
910
1011namespace stan {
@@ -19,8 +20,9 @@ namespace math {
1920 * @return The unit normal inverse cdf evaluated at p
2021 */
2122inline var inv_Phi (const var& p) {
22- return make_callback_var (inv_Phi (p.val ()), [p](auto & vi) mutable {
23- p.adj () += vi.adj () * SQRT_TWO_PI / std::exp (-0.5 * vi.val () * vi.val ());
23+ double val = inv_Phi (p.val ());
24+ return make_callback_var (val, [p, val](auto & vi) mutable {
25+ p.adj () += vi.adj () * exp (-std_normal_lpdf (val));
2426 });
2527}
2628
@@ -33,9 +35,12 @@ inline var inv_Phi(const var& p) {
3335 */
3436template <typename T, require_var_matrix_t <T>* = nullptr >
3537inline auto inv_Phi (const T& p) {
36- return make_callback_var (inv_Phi (p.val ()), [p](auto & vi) mutable {
37- p.adj ().array () += vi.adj ().array () * SQRT_TWO_PI
38- / (-0.5 * vi.val ().array ().square ()).exp ();
38+ const auto & arena_rtn = to_arena (inv_Phi (p.val ()));
39+ return make_callback_var (arena_rtn, [p, arena_rtn](auto & vi) mutable {
40+ p.adj () += apply_scalar_binary (vi.adj (), arena_rtn.val (),
41+ [](const auto & adj, const auto & rtn_val) {
42+ return adj * exp (-std_normal_lpdf (rtn_val));
43+ });
3944 });
4045}
4146
0 commit comments