Skip to content

Commit 1e105c3

Browse files
committed
Improve numerical stability of normal quantile gradients
1 parent 42d94c4 commit 1e105c3

2 files changed

Lines changed: 18 additions & 17 deletions

File tree

stan/math/rev/fun/inv_Phi.hpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33

44
#include <stan/math/rev/meta.hpp>
55
#include <stan/math/rev/core.hpp>
6-
#include <stan/math/prim/fun/constants.hpp>
76
#include <stan/math/prim/fun/inv_Phi.hpp>
7+
#include <stan/math/prim/prob/std_normal_lpdf.hpp>
8+
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
89
#include <cmath>
910

1011
namespace stan {
@@ -19,8 +20,9 @@ namespace math {
1920
* @return The unit normal inverse cdf evaluated at p
2021
*/
2122
inline var inv_Phi(const var& p) {
22-
return make_callback_var(inv_Phi(p.val()), [p](auto& vi) mutable {
23-
p.adj() += vi.adj() * SQRT_TWO_PI / std::exp(-0.5 * vi.val() * vi.val());
23+
double val = inv_Phi(p.val());
24+
return make_callback_var(val, [p, val](auto& vi) mutable {
25+
p.adj() += vi.adj() * exp(-std_normal_lpdf(val));
2426
});
2527
}
2628

@@ -33,9 +35,12 @@ inline var inv_Phi(const var& p) {
3335
*/
3436
template <typename T, require_var_matrix_t<T>* = nullptr>
3537
inline auto inv_Phi(const T& p) {
36-
return make_callback_var(inv_Phi(p.val()), [p](auto& vi) mutable {
37-
p.adj().array() += vi.adj().array() * SQRT_TWO_PI
38-
/ (-0.5 * vi.val().array().square()).exp();
38+
const auto& arena_rtn = to_arena(inv_Phi(p.val()));
39+
return make_callback_var(arena_rtn, [p, arena_rtn](auto& vi) mutable {
40+
p.adj() += apply_scalar_binary(vi.adj(), arena_rtn.val(),
41+
[](const auto& adj, const auto& rtn_val) {
42+
return adj * exp(-std_normal_lpdf(rtn_val));
43+
});
3944
});
4045
}
4146

stan/math/rev/prob/std_normal_log_qf.hpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33

44
#include <stan/math/rev/meta.hpp>
55
#include <stan/math/rev/core.hpp>
6-
#include <stan/math/prim/fun/constants.hpp>
7-
#include <stan/math/prim/fun/sign.hpp>
86
#include <stan/math/prim/prob/std_normal_log_qf.hpp>
7+
#include <stan/math/prim/functor/apply_scalar_ternary.hpp>
98
#include <cmath>
109

1110
namespace stan {
@@ -19,15 +18,12 @@ namespace math {
1918
*/
2019
template <typename T, require_stan_scalar_or_eigen_t<T>* = nullptr>
2120
inline auto std_normal_log_qf(const var_value<T>& log_p) {
22-
return make_callback_var(
23-
std_normal_log_qf(log_p.val()), [log_p](auto& vi) mutable {
24-
auto vi_array = as_array_or_scalar(vi.val());
25-
auto vi_sign = sign(as_array_or_scalar(vi.adj()));
26-
27-
const auto& deriv = as_array_or_scalar(log_p).val()
28-
+ log(as_array_or_scalar(vi.adj()) * vi_sign)
29-
- NEG_LOG_SQRT_TWO_PI + 0.5 * square(vi_array);
30-
as_array_or_scalar(log_p).adj() += vi_sign * exp(deriv);
21+
const auto& arena_rtn = to_arena(std_normal_log_qf(log_p.val()));
22+
return make_callback_var(arena_rtn, [log_p, arena_rtn](auto& vi) mutable {
23+
log_p.adj() += apply_scalar_ternary(
24+
[](const auto& adj, const auto& logp_val, const auto& rtn_val) {
25+
return adj * exp(logp_val - std_normal_lpdf(rtn_val));
26+
}, vi.adj(), log_p.val(), arena_rtn.val());
3127
});
3228
}
3329

0 commit comments

Comments
 (0)