Skip to content

Commit e9a3c56

Browse files
committed
refactor gamma_lccdf
1 parent 93d4750 commit e9a3c56

1 file changed

Lines changed: 95 additions & 76 deletions

File tree

stan/math/prim/prob/gamma_lccdf.hpp

Lines changed: 95 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,85 @@
2525

2626
namespace stan {
2727
namespace math {
28+
namespace internal {
29+
template <typename T>
30+
struct Q_eval {
31+
T log_Q{0.0};
32+
T dlogQ_dalpha{0.0};
33+
bool ok{false};
34+
};
35+
36+
/**
37+
* Computes log q and d(log q) / d(alpha) using continued fraction.
38+
*/
39+
template <typename T, typename T_shape,
40+
bool any_fvar, bool partials_fvar>
41+
static inline Q_eval<T> eval_q_cf(const T& alpha,
42+
const T& beta_y) {
43+
Q_eval<T> out;
44+
if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
45+
auto log_q_result = log_gamma_q_dgamma(value_of_rec(alpha), value_of_rec(beta_y));
46+
out.log_Q = log_q_result.log_q;
47+
out.dlogQ_dalpha = log_q_result.dlog_q_da;
48+
} else {
49+
out.log_Q = internal::log_q_gamma_cf(alpha, beta_y);
50+
if constexpr (is_autodiff_v<T_shape>) {
51+
if constexpr (!partials_fvar) {
52+
out.dlogQ_dalpha
53+
= grad_reg_inc_gamma(alpha, beta_y, tgamma(alpha),
54+
digamma(alpha)) / exp(out.log_Q);
55+
} else {
56+
T alpha_unit = alpha;
57+
alpha_unit.d_ = 1;
58+
T beta_y_unit = beta_y;
59+
beta_y_unit.d_ = 0;
60+
T log_Q_fvar = internal::log_q_gamma_cf(alpha_unit, beta_y_unit);
61+
out.dlogQ_dalpha = log_Q_fvar.d_;
62+
}
63+
}
64+
}
65+
66+
out.ok = std::isfinite(value_of_rec(out.log_Q));
67+
return out;
68+
}
69+
70+
/**
71+
* Computes log q and d(log q) / d(alpha) using log1m.
72+
*/
73+
template <typename T, typename T_shape,
74+
bool partials_fvar>
75+
static inline Q_eval<T> eval_q_log1m(const T& alpha,
76+
const T& beta_y) {
77+
Q_eval<T> out;
78+
out.log_Q = log1m(gamma_p(alpha, beta_y));
79+
80+
if (!std::isfinite(value_of_rec(out.log_Q))) {
81+
out.ok = false;
82+
return out;
83+
}
84+
85+
if constexpr (is_autodiff_v<T_shape>) {
86+
if constexpr (partials_fvar) {
87+
T alpha_unit = alpha;
88+
alpha_unit.d_ = 1;
89+
T beta_unit = beta_y;
90+
beta_unit.d_ = 0;
91+
T log_Q_fvar = log1m(gamma_p(alpha_unit, beta_unit));
92+
out.dlogQ_dalpha = log_Q_fvar.d_;
93+
} else {
94+
out.dlogQ_dalpha = -grad_reg_lower_inc_gamma(alpha, beta_y) / exp(out.log_Q);
95+
}
96+
}
97+
98+
out.ok = true;
99+
return out;
100+
}
101+
}
28102

29103
template <typename T_y, typename T_shape, typename T_inv_scale>
30-
inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
31-
const T_y& y, const T_shape& alpha, const T_inv_scale& beta) {
104+
inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(const T_y& y,
105+
const T_shape& alpha,
106+
const T_inv_scale& beta) {
32107
using std::exp;
33108
using std::log;
34109
using T_partials_return = partials_return_t<T_y, T_shape, T_inv_scale>;
@@ -81,91 +156,35 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
81156
return ops_partials.build(negative_infinity());
82157
}
83158

84-
bool use_cf = beta_y > alpha_dbl + 1.0;
85-
T_partials_return log_Qn;
86-
[[maybe_unused]] T_partials_return dlogQ_dalpha = 0.0;
87-
88-
// Branch by autodiff type first, then handle use_cf logic inside each path
89-
if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
90-
// var-only path: use log_gamma_q_dgamma which computes both log_q
91-
// and its gradient analytically with double inputs
92-
const double beta_y_dbl = value_of_rec(beta_y);
93-
const double alpha_dbl_val = value_of_rec(alpha_dbl);
94-
95-
if (use_cf) {
96-
auto log_q_result = log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
97-
log_Qn = log_q_result.log_q;
98-
dlogQ_dalpha = log_q_result.dlog_q_da;
99-
} else {
100-
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
101-
log_Qn = log1m(Pn);
102-
const T_partials_return Qn = exp(log_Qn);
103-
104-
// Check if we need to fallback to continued fraction
105-
bool need_cf_fallback
106-
= !std::isfinite(value_of_rec(log_Qn)) || Qn <= 0.0;
107-
if (need_cf_fallback && beta_y > 0.0) {
108-
auto log_q_result = log_gamma_q_dgamma(alpha_dbl_val, beta_y_dbl);
109-
log_Qn = log_q_result.log_q;
110-
dlogQ_dalpha = log_q_result.dlog_q_da;
111-
} else {
112-
dlogQ_dalpha = -grad_reg_lower_inc_gamma(alpha_dbl, beta_y) / Qn;
113-
}
114-
}
115-
} else if constexpr (partials_fvar && is_autodiff_v<T_shape>) {
116-
// fvar path: use unit derivative trick to compute gradients
117-
T_partials_return alpha_unit = alpha_dbl;
118-
alpha_unit.d_ = 1;
119-
T_partials_return beta_unit = beta_y;
120-
beta_unit.d_ = 0;
121-
122-
if (use_cf) {
123-
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
124-
T_partials_return log_Qn_fvar
125-
= internal::log_q_gamma_cf(alpha_unit, beta_unit);
126-
dlogQ_dalpha = log_Qn_fvar.d_;
127-
} else {
128-
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
129-
log_Qn = log1m(Pn);
130-
131-
if (!std::isfinite(value_of_rec(log_Qn)) && beta_y > 0.0) {
132-
// Fallback to continued fraction
133-
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
134-
T_partials_return log_Qn_fvar
135-
= internal::log_q_gamma_cf(alpha_unit, beta_unit);
136-
dlogQ_dalpha = log_Qn_fvar.d_;
137-
} else {
138-
T_partials_return log_Qn_fvar = log1m(gamma_p(alpha_unit, beta_unit));
139-
dlogQ_dalpha = log_Qn_fvar.d_;
140-
}
141-
}
159+
const bool use_continued_fraction = beta_y > alpha_dbl + 1.0;
160+
internal::Q_eval<T_partials_return> result;
161+
if (use_continued_fraction) {
162+
result = internal::eval_q_cf<T_partials_return, T_shape,
163+
any_fvar, partials_fvar>(alpha_dbl, beta_y);
142164
} else {
143-
// No alpha derivative needed (alpha is constant or double-only)
144-
if (use_cf) {
145-
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
146-
} else {
147-
const T_partials_return Pn = gamma_p(alpha_dbl, beta_y);
148-
log_Qn = log1m(Pn);
165+
result = internal::eval_q_log1m<T_partials_return, T_shape,
166+
partials_fvar>(alpha_dbl, beta_y);
149167

150-
if (!std::isfinite(value_of_rec(log_Qn)) && beta_y > 0.0) {
151-
log_Qn = internal::log_q_gamma_cf(alpha_dbl, beta_y);
152-
}
168+
if (!result.ok && beta_y > 0.0) {
169+
// Fallback to continued fraction if log1m fails
170+
result = internal::eval_q_cf<T_partials_return, T_shape,
171+
any_fvar, partials_fvar>(alpha_dbl, beta_y);
153172
}
154173
}
155-
if (!std::isfinite(value_of_rec(log_Qn))) {
174+
if (!result.ok) {
156175
return ops_partials.build(negative_infinity());
157176
}
158-
P += log_Qn;
177+
178+
P += result.log_Q;
159179

160180
if constexpr (is_autodiff_v<T_y> || is_autodiff_v<T_inv_scale>) {
161181
const T_partials_return log_y = log(y_dbl);
162182
const T_partials_return alpha_minus_one = fma(alpha_dbl, log_y, -log_y);
163183

164-
const T_partials_return log_pdf = alpha_dbl * log(beta_dbl)
165-
- lgamma(alpha_dbl) + alpha_minus_one
166-
- beta_y;
184+
const T_partials_return log_pdf
185+
= alpha_dbl * log(beta_dbl) - lgamma(alpha_dbl) + alpha_minus_one - beta_y;
167186

168-
const T_partials_return hazard = exp(log_pdf - log_Qn); // f/Q
187+
const T_partials_return hazard = exp(log_pdf - result.log_Q); // f/Q
169188

170189
if constexpr (is_autodiff_v<T_y>) {
171190
partials<0>(ops_partials)[n] -= hazard;
@@ -175,7 +194,7 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
175194
}
176195
}
177196
if constexpr (is_autodiff_v<T_shape>) {
178-
partials<1>(ops_partials)[n] += dlogQ_dalpha;
197+
partials<1>(ops_partials)[n] += result.dlogQ_dalpha;
179198
}
180199
}
181200
return ops_partials.build(P);

0 commit comments

Comments
 (0)