Skip to content

Commit 85a248e

Browse files
committed
Mix fixes
1 parent bae12ac commit 85a248e

7 files changed

Lines changed: 16 additions & 11 deletions

File tree

stan/math/fwd/fun/log1m_inv_logit.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace math {
2020
template <typename T>
2121
inline fvar<T> log1m_inv_logit(const fvar<T>& x) {
2222
using std::exp;
23-
return fvar<T>(log1m_inv_logit(x.val_), -x.d_ / (1 + exp(-x.val_)));
23+
return fvar<T>(log1m_inv_logit(x.val_), -x.d_ * inv_logit(x.val_));
2424
}
2525

2626
} // namespace math

stan/math/fwd/fun/log1p_exp.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace math {
1212
template <typename T>
1313
inline fvar<T> log1p_exp(const fvar<T>& x) {
1414
using std::exp;
15-
return fvar<T>(log1p_exp(x.val_), x.d_ / (1 + exp(-x.val_)));
15+
return fvar<T>(log1p_exp(x.val_), x.d_ * inv_logit(x.val_));
1616
}
1717

1818
} // namespace math

stan/math/fwd/fun/log_inv_logit.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <stan/math/fwd/meta.hpp>
55
#include <stan/math/fwd/core.hpp>
6+
#include <stan/math/prim/fun/inv_logit.hpp>
67
#include <stan/math/prim/fun/log_inv_logit.hpp>
78
#include <cmath>
89

@@ -12,7 +13,7 @@ namespace math {
1213
template <typename T>
1314
inline fvar<T> log_inv_logit(const fvar<T>& x) {
1415
using std::exp;
15-
return fvar<T>(log_inv_logit(x.val_), x.d_ / (1 + exp(x.val_)));
16+
return fvar<T>(log_inv_logit(x.val_), x.d_ * inv_logit(-x.val_));
1617
}
1718
} // namespace math
1819
} // namespace stan

stan/math/fwd/fun/log_sum_exp.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ template <typename T>
1818
inline fvar<T> log_sum_exp(const fvar<T>& x1, const fvar<T>& x2) {
1919
using std::exp;
2020
return fvar<T>(log_sum_exp(x1.val_, x2.val_),
21-
x1.d_ / (1 + exp(x2.val_ - x1.val_))
22-
+ x2.d_ / (exp(x1.val_ - x2.val_) + 1));
21+
x1.d_ * inv_logit(-(x2.val_ - x1.val_))
22+
+ x2.d_ * inv_logit(-(x1.val_ - x2.val_)));
2323
}
2424

2525
template <typename T>
@@ -28,7 +28,7 @@ inline fvar<T> log_sum_exp(double x1, const fvar<T>& x2) {
2828
if (x1 == NEGATIVE_INFTY) {
2929
return fvar<T>(x2.val_, x2.d_);
3030
}
31-
return fvar<T>(log_sum_exp(x1, x2.val_), x2.d_ / (exp(x1 - x2.val_) + 1));
31+
return fvar<T>(log_sum_exp(x1, x2.val_), x2.d_ * inv_logit(-(x1 - x2.val_)));
3232
}
3333

3434
template <typename T>

stan/math/prim/prob/logistic_cdf.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <stan/math/prim/fun/size.hpp>
1111
#include <stan/math/prim/fun/size_zero.hpp>
1212
#include <stan/math/prim/fun/value_of.hpp>
13+
#include <stan/math/prim/fun/inv_logit.hpp>
1314
#include <stan/math/prim/prob/logistic_log.hpp>
1415
#include <stan/math/prim/functor/partials_propagator.hpp>
1516
#include <cmath>
@@ -70,8 +71,8 @@ return_type_t<T_y, T_loc, T_scale> logistic_cdf(const T_y& y, const T_loc& mu,
7071
const T_partials_return sigma_dbl = sigma_vec.val(n);
7172
const T_partials_return sigma_inv_vec = 1.0 / sigma_vec.val(n);
7273

73-
const T_partials_return Pn
74-
= 1.0 / (1.0 + exp(-(y_dbl - mu_dbl) * sigma_inv_vec));
74+
// TODO(Andrew) Further simplify derivatives and log scale below
75+
const T_partials_return Pn = inv_logit((y_dbl - mu_dbl) * sigma_inv_vec);
7576

7677
P *= Pn;
7778

stan/math/prim/prob/logistic_lccdf.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <stan/math/prim/fun/constants.hpp>
77
#include <stan/math/prim/fun/exp.hpp>
88
#include <stan/math/prim/fun/log.hpp>
9+
#include <stan/math/prim/fun/inv_logit.hpp>
910
#include <stan/math/prim/fun/max_size.hpp>
1011
#include <stan/math/prim/fun/scalar_seq_view.hpp>
1112
#include <stan/math/prim/fun/size.hpp>
@@ -71,8 +72,9 @@ return_type_t<T_y, T_loc, T_scale> logistic_lccdf(const T_y& y, const T_loc& mu,
7172
const T_partials_return sigma_dbl = sigma_vec.val(n);
7273
const T_partials_return sigma_inv_vec = 1.0 / sigma_vec.val(n);
7374

75+
// TODO(Andrew) Further simplify derivatives and log-scale below
7476
const T_partials_return Pn
75-
= 1.0 - 1.0 / (1.0 + exp(-(y_dbl - mu_dbl) * sigma_inv_vec));
77+
= 1.0 - inv_logit((y_dbl - mu_dbl) * sigma_inv_vec);
7678
P += log(Pn);
7779

7880
if (!is_constant_all<T_y>::value) {

stan/math/prim/prob/logistic_lcdf.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <stan/math/prim/fun/constants.hpp>
77
#include <stan/math/prim/fun/exp.hpp>
88
#include <stan/math/prim/fun/log.hpp>
9+
#include <stan/math/prim/fun/inv_logit.hpp>
910
#include <stan/math/prim/fun/scalar_seq_view.hpp>
1011
#include <stan/math/prim/fun/max_size.hpp>
1112
#include <stan/math/prim/fun/size.hpp>
@@ -71,8 +72,8 @@ return_type_t<T_y, T_loc, T_scale> logistic_lcdf(const T_y& y, const T_loc& mu,
7172
const T_partials_return sigma_dbl = sigma_vec.val(n);
7273
const T_partials_return sigma_inv_vec = 1.0 / sigma_vec.val(n);
7374

74-
const T_partials_return Pn
75-
= 1.0 / (1.0 + exp(-(y_dbl - mu_dbl) * sigma_inv_vec));
75+
// TODO(Andrew) Further simplify derivatives and log-scale below
76+
const T_partials_return Pn = inv_logit((y_dbl - mu_dbl) * sigma_inv_vec);
7677
P += log(Pn);
7778

7879
if (!is_constant_all<T_y>::value) {

0 commit comments

Comments
 (0)